fit_a_nef.utils.TrainState

class fit_a_nef.utils.TrainState(step: int, apply_fn: Callable, params: flax.core.frozen_dict.FrozenDict[str, Any], tx: optax._src.base.GradientTransformation, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], rng: Any = None)
__init__(step: int, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], rng: Any = None) → None

Methods

__init__(step, apply_fn, params, tx, opt_state[, rng])

apply_gradients(*, grads, **kwargs)

Updates `step`, `params`, `opt_state` and `**kwargs` in the return value.

create(*, apply_fn, params, tx, **kwargs)

Creates a new instance with step=0 and initialized opt_state.

replace(**updates)

"Returns a new object replacing the specified fields with new values.

Attributes

rng

step

apply_fn

params

tx

opt_state

replace(**updates)

Returns a new object replacing the specified fields with new values.