fit_a_nef.trainer.SignalTrainer
- class fit_a_nef.trainer.SignalTrainer(coords: Array, signals: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, train_rng: Array, num_signals: int, num_steps: int = 20000, verbose: bool = False)
Base class for training neural networks.
- Parameters:
coords (jnp.ndarray) – the coordinates to train on
signals (jnp.ndarray) – the signal values to train on, e.g. images (pixel values) or objects (occupancy values)
nef_cfg (Dict[str, Any]) – the config for the neural network
scheduler_cfg (Dict[str, Any]) – the config for the scheduler
optimizer_cfg (Dict[str, Any]) – the config for the optimizer
initializer (InitModel) – the initializer for the model, see initializers.py for more info
train_rng (jnp.ndarray) – the random number generator to use for training
num_signals (int) – the number of signals being fit
num_steps (int, optional) – the number of steps to train for, defaults to 20000
verbose (bool, optional) – whether to log training progress or not, defaults to False. Override the verbose_train_model() function to change the logging behavior.
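The parameter list does not spell out the expected array layouts. The sketch below shows one plausible way to prepare image data. It assumes that coords is a single shared (num_points, 2) grid normalized to [-1, 1] and that signals is a (num_signals, num_points, channels) array. Both helpers, make_image_coords and images_to_signals, are hypothetical and not part of fit_a_nef. Check the library's own examples for the exact shapes SignalTrainer expects.

```python
import jax.numpy as jnp


def make_image_coords(height: int, width: int) -> jnp.ndarray:
    """Return a (height * width, 2) grid of (y, x) coordinates in [-1, 1] (hypothetical helper)."""
    ys = jnp.linspace(-1.0, 1.0, height)
    xs = jnp.linspace(-1.0, 1.0, width)
    grid = jnp.stack(jnp.meshgrid(ys, xs, indexing="ij"), axis=-1)  # (H, W, 2)
    return grid.reshape(-1, 2)


def images_to_signals(images: jnp.ndarray) -> jnp.ndarray:
    """Flatten (num_signals, H, W, C) images to (num_signals, H * W, C) (hypothetical helper)."""
    num_signals, height, width, channels = images.shape
    return images.reshape(num_signals, height * width, channels)


images = jnp.zeros((4, 28, 28, 1))  # placeholder batch of 4 grayscale images
coords = make_image_coords(28, 28)
signals = images_to_signals(images)
print(coords.shape, signals.shape)  # (784, 2) (4, 784, 1)
```

Here num_signals would be 4, matching the leading axis of signals.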
- __init__(coords: Array, signals: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, train_rng: Array, num_signals: int, num_steps: int = 20000, verbose: bool = False)
Constructor.
Methods
__init__(coords, signals, nef_cfg, ...[, ...]) – Constructor.
apply_model(coords[, model_id]) – Applies the model to a given set of coordinates.
clean_up([clear_caches]) – Cleans up the trainer by deleting the state and train_step attributes.
compile() – Executes the training function once to compile the train_step.
create_functions() – Creates the functions needed for training the model.
create_train_model() – Creates the train_model function.
create_train_step()
fast_train_model() – Quickly trains the model for the number of steps specified in the init function.
get_flat_params([model_id]) – Returns the flattened params for a given model ID or all params if no model ID is specified.
get_lr() – Returns the current learning rate.
get_params([model_id]) – Returns the params for a given model ID or all params if no model ID is specified.
init_model(example_input) – Initializes the model parameters using the initializer defined in the constructor.
load(path) – Used to load the parameters from an HDF5 file that can be created using the save function.
process_batch(state, coords, signals) – Used to process the batch before passing it to the loss function.
save(path, **kwargs) – Save the parameters to an HDF5 file.
verbose_train_model() – Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
- apply_model(coords: Array, model_id: int | None = None) → Array
Applies the model to a given set of coordinates.
- Parameters:
coords (jnp.ndarray) – The coordinates to apply the model to.
model_id (int, optional) – The model ID to apply. Defaults to None, in which case all models are used.
- Returns:
The output of the model.
- Return type:
jnp.ndarray
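The sketch below shows how one call can serve both a single model and all models. It assumes the parameters of all fitted signals are stacked along a leading "model" axis. When model_id is None, the same coordinates are mapped through every model with jax.vmap. Otherwise, one slice of the parameters is used. The MLP (mlp_apply) and the stacked-parameter layout are hypothetical stand-ins, not the actual NeF defined by nef_cfg.

```python
import jax
import jax.numpy as jnp


def mlp_apply(params, coords):
    """Hypothetical one-hidden-layer MLP standing in for the NeF."""
    hidden = jnp.sin(coords @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]


def apply_model(stacked_params, coords, model_id=None):
    """Apply one model (model_id) or all models (model_id=None) to coords."""
    if model_id is None:
        # vmap over the leading "model" axis of every parameter array; coords are shared
        return jax.vmap(mlp_apply, in_axes=(0, None))(stacked_params, coords)
    single = jax.tree_util.tree_map(lambda p: p[model_id], stacked_params)
    return mlp_apply(single, coords)


num_models, in_dim, hidden_dim, out_dim = 3, 2, 16, 1
keys = jax.random.split(jax.random.PRNGKey(0), 2)
stacked_params = {
    "w1": jax.random.normal(keys[0], (num_models, in_dim, hidden_dim)),
    "b1": jnp.zeros((num_models, hidden_dim)),
    "w2": jax.random.normal(keys[1], (num_models, hidden_dim, out_dim)),
    "b2": jnp.zeros((num_models, out_dim)),
}
coords = jnp.linspace(-1.0, 1.0, 20).reshape(10, in_dim)
print(apply_model(stacked_params, coords).shape)     # (3, 10, 1)
print(apply_model(stacked_params, coords, 1).shape)  # (10, 1)
```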
- clean_up(clear_caches=True)
Cleans up the trainer by deleting the state and train_step attributes. This is useful to free up memory.
- Parameters:
clear_caches (bool, optional) – Whether to clear the JAX caches or not. Defaults to True.
- Returns:
None
- Return type:
None
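A minimal sketch of what such a clean-up can look like follows. Deleting the attributes drops the references to the device buffers and the compiled function. jax.clear_caches() (available in recent JAX releases) additionally drops JAX's compilation caches. The TinyTrainer class is hypothetical, and fit_a_nef's actual clean_up may do more or less than this.

```python
import jax
import jax.numpy as jnp


class TinyTrainer:
    """Hypothetical stand-in holding some state and a compiled step."""

    def __init__(self):
        self.state = {"params": jnp.ones((1000,))}
        self.train_step = jax.jit(lambda x: x * 2.0)

    def clean_up(self, clear_caches: bool = True) -> None:
        # Drop the references so the arrays and the jitted function can be garbage collected.
        del self.state
        del self.train_step
        if clear_caches:
            # Also clear JAX's global compilation/tracing caches.
            jax.clear_caches()


trainer = TinyTrainer()
trainer.train_step(jnp.ones(3))  # populate the jit cache
trainer.clean_up()
print(hasattr(trainer, "state"), hasattr(trainer, "train_step"))  # False False
```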
- compile() → None
Executes the training function once to compile the train_step.
- Returns:
None
- Return type:
None
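Running the step once up front moves the tracing and XLA compilation cost out of the timed training loop. The sketch below makes this visible with a Python counter. The counter only increments while JAX traces the function, so later calls with the same shapes and dtypes reuse the compiled executable. The toy least-squares step is an illustrative assumption, not the trainer's actual train_step.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def train_step(params, x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    loss_fn = lambda p: jnp.mean((x * p - y) ** 2)
    loss, grad = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grad, loss


x = jnp.linspace(0.0, 1.0, 8)
y = 3.0 * x
params = jnp.zeros(())  # float32 scalar, same dtype as the step's output

# "compile": one warm-up call triggers tracing and XLA compilation
params, loss = train_step(params, x, y)
loss.block_until_ready()

# later calls with identical shapes/dtypes hit the compilation cache
for _ in range(10):
    params, loss = train_step(params, x, y)
print(trace_count)  # 1
```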
- create_functions()
Creates the functions needed for training the model. This includes the loss function, the train step and the train model functions.
This is needed to allow for proper handling of the JAX JIT compilation of the train_step function.
For the train_model function, this allows switching between verbose and fast training without any computational overhead during the actual fitting.
- create_train_model() → None
Creates the train_model function. This is used to switch between verbose and fast training without any computational overhead during the actual fitting.
- Returns:
None
- Return type:
None
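The "no overhead" switch described for create_functions and create_train_model can be sketched as binding one of two methods to train_model once, at construction time. In this sketch, the fast path runs the whole loop inside a single compiled jax.lax.fori_loop. The verbose path uses a Python loop over the jitted step so it can print the loss every 100 steps. SketchTrainer, its toy loss and its learning rate are hypothetical. Only the method names mirror those documented here.

```python
import jax
import jax.numpy as jnp


class SketchTrainer:
    """Hypothetical trainer that fixes its training loop once, at construction time."""

    def __init__(self, num_steps: int, verbose: bool = False):
        self.num_steps = num_steps
        self.verbose = verbose
        self.params = jnp.zeros(())
        self.x = jnp.linspace(0.0, 1.0, 8)
        self.y = 3.0 * self.x
        self.train_step = jax.jit(self._step)
        self.create_train_model()

    def _step(self, params):
        # One gradient-descent step on a toy least-squares loss.
        loss_fn = lambda p: jnp.mean((self.x * p - self.y) ** 2)
        loss, grad = jax.value_and_grad(loss_fn)(params)
        return params - 0.5 * grad, loss

    def create_train_model(self) -> None:
        # Decide once which loop to use; no per-step `if self.verbose` check remains.
        self.train_model = self.verbose_train_model if self.verbose else self.fast_train_model

    def fast_train_model(self) -> None:
        # Entire loop runs inside one compiled lax.fori_loop: no Python overhead per step.
        body = lambda _, p: self._step(p)[0]
        run = jax.jit(lambda p: jax.lax.fori_loop(0, self.num_steps, body, p))
        self.params = run(self.params)

    def verbose_train_model(self) -> None:
        # Python loop over the jitted step so the loss can be logged every 100 steps.
        for step in range(self.num_steps):
            self.params, loss = self.train_step(self.params)
            if step % 100 == 0:
                print(f"step {step}: loss {float(loss):.6f}")


fast = SketchTrainer(num_steps=500)
fast.train_model()
verbose = SketchTrainer(num_steps=500, verbose=True)
verbose.train_model()
print(float(fast.params), float(verbose.params))
```

Both trainers reach the same parameters. Only the verbose one pays Python-loop overhead and prints progress.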
- fast_train_model() → None
Quickly trains the model for the number of steps specified in the init function.
- get_flat_params(model_id: int | None = None) → Tuple[Array, List[Tuple[str, List[int]]]]
Returns the flattened params for a given model ID or all params if no model ID is specified.
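The return type pairs a flat array with a list of (str, List[int]) entries. A natural reading is one entry per parameter tensor, giving its name and shape, so that the flat vector can be split back up. The sketch below shows that interpretation, which is an assumption since the exact naming scheme is not documented here. The flatten_params helper and the "/"-joined names are hypothetical.

```python
import jax.numpy as jnp


def flatten_params(params, prefix=""):
    """Flatten a nested dict of arrays into one vector plus (name, shape) entries."""
    vectors, structure = [], []
    for key in sorted(params):
        name = f"{prefix}/{key}" if prefix else key
        value = params[key]
        if isinstance(value, dict):
            sub_vec, sub_struct = flatten_params(value, name)
            vectors.append(sub_vec)
            structure.extend(sub_struct)
        else:
            vectors.append(jnp.ravel(value))
            structure.append((name, list(value.shape)))
    return jnp.concatenate(vectors), structure


params = {
    "dense_0": {"kernel": jnp.ones((2, 4)), "bias": jnp.zeros((4,))},
    "dense_1": {"kernel": jnp.ones((4, 1)), "bias": jnp.zeros((1,))},
}
flat, structure = flatten_params(params)
print(flat.shape)  # (17,)
print(structure)
# [('dense_0/bias', [4]), ('dense_0/kernel', [2, 4]), ('dense_1/bias', [1]), ('dense_1/kernel', [4, 1])]
```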
- get_params(model_id: int | None = None) → Array
Returns the params for a given model ID or all params if no model ID is specified.
- Parameters:
model_id (int, optional) – The model ID to get the params for or None to get all params.
- Returns:
The params for the given model ID or all params if no model ID is specified.
- Return type:
jnp.ndarray
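The sketch below selects one model's parameters. It assumes the parameters for all signals are stored as a pytree whose leaves share a leading "model" axis, which is an assumption. The documented return type is a single jnp.ndarray, so the real storage layout may differ. With model_id=None, everything is returned as-is. Otherwise, each leaf is indexed at model_id.

```python
import jax
import jax.numpy as jnp


def get_params(stacked_params, model_id=None):
    """Return params of one model, or all stacked params when model_id is None."""
    if model_id is None:
        return stacked_params
    # Index the leading "model" axis of every leaf in the pytree.
    return jax.tree_util.tree_map(lambda leaf: leaf[model_id], stacked_params)


stacked = {"kernel": jnp.arange(12.0).reshape(3, 2, 2), "bias": jnp.zeros((3, 2))}
single = get_params(stacked, model_id=2)
print(single["kernel"].shape)  # (2, 2)
```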
- init_model(example_input: Array) → None
Initializes the model parameters using the initializer defined in the constructor.
- Parameters:
example_input (jnp.ndarray) – An example input to the model. Used by JAX to initialize the model with the correct input shapes.
- Returns:
None
- Return type:
None
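The sketch below shows why an example input is needed and how many independent models can be initialized at once. Parameter shapes are read off example_input, in the spirit of the common Flax init(rng, example_input) pattern. jax.vmap is then applied over one split PRNG key per signal. init_single and init_models are hypothetical and do not reproduce the initializers in initializers.py.

```python
import jax
import jax.numpy as jnp


def init_single(key, example_input, hidden_dim=16, out_dim=1):
    """Hypothetical initializer for one small MLP, shaped from example_input."""
    in_dim = example_input.shape[-1]
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden_dim)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros((hidden_dim,)),
        "w2": jax.random.normal(k2, (hidden_dim, out_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros((out_dim,)),
    }


def init_models(rng, example_input, num_signals):
    """Initialize num_signals independent models by vmapping over split keys."""
    keys = jax.random.split(rng, num_signals)
    return jax.vmap(init_single, in_axes=(0, None))(keys, example_input)


example_input = jnp.zeros((1, 2))  # one 2-D coordinate
params = init_models(jax.random.PRNGKey(0), example_input, num_signals=5)
print(params["w1"].shape)  # (5, 2, 16)
```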
- load(path: Path) → None
Used to load the parameters from an HDF5 file that can be created using the save function.
- Parameters:
path (Path) – The path to load the parameters from.
- Returns:
None
- Return type:
None
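As a hedged illustration of an HDF5 round trip, the sketch below uses h5py to write one dataset per parameter array and read them all back. The file layout is an assumption, not fit_a_nef's actual format. So is the treatment of extra keyword arguments as file attributes: what save(path, **kwargs) really does with them is not documented here. The helper names save_params and load_params are hypothetical.

```python
import os
import tempfile

import h5py
import numpy as np


def save_params(path, params, **attrs):
    """Write a flat dict of arrays to HDF5; extra kwargs become file attributes (assumption)."""
    with h5py.File(path, "w") as f:
        for name, value in params.items():
            f.create_dataset(name, data=np.asarray(value))
        for key, value in attrs.items():
            f.attrs[key] = value


def load_params(path):
    """Read every top-level dataset back into a dict of NumPy arrays."""
    with h5py.File(path, "r") as f:
        return {name: f[name][()] for name in f.keys()}


params = {"w1": np.ones((5, 2, 16), dtype=np.float32), "b1": np.zeros((5, 16), dtype=np.float32)}
path = os.path.join(tempfile.mkdtemp(), "params.h5")
save_params(path, params, num_steps=20000)
restored = load_params(path)
print(sorted(restored), restored["w1"].shape)  # ['b1', 'w1'] (5, 2, 16)
```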
- process_batch(state: TrainState, coords: Array, signals: Array) → Tuple[Array, Array, Array]
Used to process the batch before passing it to the loss function. This is useful for selecting specific coordinates or changing the shapes of the signals.
- Parameters:
state (TrainState) – the current state of the training, used for functional programming.
coords (jnp.ndarray) – the coordinates to process.
signals (jnp.ndarray) – the signals to process.
- Returns:
the processed coordinates, signals and the random number generator.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
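One typical use of this hook is random coordinate subsampling, with the random key threaded through functionally, which is why a random number generator is part of the return value. The sketch below takes a PRNG key directly instead of a TrainState, which is an assumption. It draws the same random subset of points for every signal and returns the next key. The function name and the num_samples argument are hypothetical.

```python
import jax
import jax.numpy as jnp


def process_batch(rng, coords, signals, num_samples):
    """Hypothetical batch processing: subsample the same random coordinates for all signals.

    coords: (num_points, coord_dim); signals: (num_signals, num_points, channels).
    Returns the subsampled coords, the subsampled signals and the next rng key.
    """
    rng, sample_key = jax.random.split(rng)
    idx = jax.random.choice(sample_key, coords.shape[0], shape=(num_samples,), replace=False)
    return coords[idx], signals[:, idx], rng


coords = jnp.linspace(-1.0, 1.0, 200).reshape(100, 2)
signals = jnp.ones((4, 100, 3))
sub_coords, sub_signals, next_rng = process_batch(jax.random.PRNGKey(0), coords, signals, 32)
print(sub_coords.shape, sub_signals.shape)  # (32, 2) (4, 32, 3)
```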