fit_a_nef.utils.get_scheduler

fit_a_nef.utils.get_scheduler(scheduler_cfg: Dict[str, Any]) → Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

Builds and returns the scheduler specified by the given config. All schedulers provided by optax are supported.

Parameters:

scheduler_cfg (ConfigDict) – The config for the scheduler.

Raises:

NotImplementedError – If the scheduler is not implemented.

Returns:

The scheduler.

Return type:

optax.Schedule