fit_a_nef.utils.TrainState
- class fit_a_nef.utils.TrainState(step: int, apply_fn: Callable, params: flax.core.frozen_dict.FrozenDict[str, Any], tx: optax._src.base.GradientTransformation, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], rng: Any = None)
- __init__(step: int, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], rng: Any = None) → None
Methods
__init__(step, apply_fn, params, tx, opt_state)
apply_gradients(*, grads, **kwargs)
    Updates step, params, opt_state and **kwargs in return value.
create(*, apply_fn, params, tx, **kwargs)
    Creates a new instance with step=0 and initialized opt_state.
replace(**updates)
    Returns a new object replacing the specified fields with new values.
Attributes
rng
step
apply_fn
params
tx
opt_state
- replace(**updates)
Returns a new object replacing the specified fields with new values.