fit_a_nef.utils.get_optimizer

fit_a_nef.utils.get_optimizer(optimizer_cfg: Dict[str, Any], scheduler: optax.Schedule) → optax.GradientTransformation

Build an optax optimizer from the provided configuration and learning-rate schedule.

Parameters:
  • optimizer_cfg (Dict[str, Any]) – Configuration for the optimizer.

  • scheduler (optax.Schedule) – Learning-rate schedule: a callable that maps the current step count to a value.

Raises:

NotImplementedError – If the specified optimizer is not implemented in optax.

Returns:

The optimizer built from optimizer_cfg and scheduler.

Return type:

optax.GradientTransformation