fit_a_nef.utils.get_optimizer
- fit_a_nef.utils.get_optimizer(optimizer_cfg: Dict[str, Any], scheduler: Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]) → GradientTransformation
Return the optimizer specified by the configuration, using the given learning rate schedule.
- Parameters:
optimizer_cfg (Dict[str, Any]) – Configuration for the optimizer.
scheduler (optax.Schedule) – Learning rate schedule.
- Raises:
NotImplementedError – If the specified optimizer is not implemented in optax.
- Returns:
Optimizer instance.
- Return type:
optax.GradientTransformation
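The behaviour documented above (look up an optimizer by configuration, pass it the schedule, raise NotImplementedError if it is unknown) can be illustrated with a minimal sketch. This is not the fit_a_nef implementation: the function name build_optimizer, the config keys "name" and "params", and the library argument are all assumptions made for illustration, and a small stub stands in for optax so the example runs without extra dependencies.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


def build_optimizer(optimizer_cfg: Dict[str, Any], scheduler: Callable, library: Any) -> Any:
    """Hypothetical stand-in for get_optimizer (not the fit_a_nef code).

    Assumes the config holds a "name" key (optimizer factory name) and an
    optional "params" key (extra keyword arguments); both keys are assumptions.
    `library` is the module searched for the factory, e.g. the optax module.
    """
    name = optimizer_cfg["name"]
    factory = getattr(library, name, None)
    if factory is None:
        # Mirrors the documented error for optimizers the library lacks.
        raise NotImplementedError(f"Optimizer {name!r} is not implemented")
    # The schedule is passed as the learning rate, plus any extra parameters.
    return factory(learning_rate=scheduler, **optimizer_cfg.get("params", {}))


# Stub library: its "adam" simply records what it was called with.
stub = SimpleNamespace(
    adam=lambda learning_rate, **kwargs: ("adam", learning_rate(0), kwargs)
)


def constant_schedule(step):
    """Schedule that returns the same learning rate at every step."""
    return 1e-3


opt = build_optimizer({"name": "adam", "params": {"b1": 0.9}}, constant_schedule, stub)
```

With optax installed, the same sketch could presumably be called as build_optimizer({"name": "adam"}, optax.constant_schedule(1e-3), optax), since optax.adam accepts a schedule as its learning_rate. The actual keys expected in optimizer_cfg are not stated on this page, so check the fit_a_nef source before relying on them.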