fit_a_nef.trainer.SignalTrainer

class fit_a_nef.trainer.SignalTrainer(coords: Array, signals: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, train_rng: Array, num_signals: int, num_steps: int = 20000, verbose: bool = False)

Base class for training neural networks.

Parameters:
  • coords (jnp.ndarray) – the coordinates to train on

  • signals (jnp.ndarray) – the signal values to train on, e.g. images (pixel values) or objects (occupancy)

  • nef_cfg (Dict[str, Any]) – the config for the neural network

  • scheduler_cfg (Dict[str, Any]) – the config for the scheduler

  • optimizer_cfg (Dict[str, Any]) – the config for the optimizer

  • initializer (InitModel) – the initializer for the model, see initializers.py for more info

  • train_rng (jnp.ndarray) – the JAX random number generator key to use for training

  • num_signals (int) – the number of signals being fit

  • num_steps (int, optional) – the number of steps to train for, defaults to 20000

  • verbose (bool, optional) – whether to log progress during training, defaults to False. Override the verbose_train_model() function to change the logging behavior.
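The constructor does not state the exact array shapes it expects. As a minimal sketch, assuming a batch of images stored as (num_signals, H, W, C) and a single coordinate grid in [-1, 1] shared by all signals, coords and signals could be prepared as below. make_image_inputs is a hypothetical helper, not part of fit_a_nef; NumPy is used here, and jax.numpy provides the same calls.

```python
import numpy as np


def make_image_inputs(images: np.ndarray):
    """Hypothetical helper: build shared coords and per-image signals.

    Assumes `images` has shape (num_signals, H, W, C); the layout that
    SignalTrainer actually expects may differ.
    """
    num_signals, height, width, channels = images.shape
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    # (row, column) coordinates for every pixel, shape (H, W, 2)
    grid = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)
    coords = grid.reshape(-1, 2)  # (H*W, 2), row-major pixel order
    # Same row-major order, so signals[:, k] is the value at coords[k]
    signals = images.reshape(num_signals, height * width, channels)
    return coords, signals
```

Because both reshapes use the same row-major order, each coordinate row stays paired with the pixel value at that position.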

__init__(coords: Array, signals: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, train_rng: Array, num_signals: int, num_steps: int = 20000, verbose: bool = False)

Constructor.

Methods

__init__(coords, signals, nef_cfg, ...[, ...])

Constructor.

apply_model(coords[, model_id])

Applies the model to a given set of coordinates.

clean_up([clear_caches])

Cleans up the trainer by deleting the state and train_step attributes.

compile()

Executes the training function once to compile the train_step.

create_functions()

Creates the functions needed for training the model.

create_train_model()

Creates the train_model function.

create_train_step()

fast_train_model()

Quickly trains the model for the number of steps specified in the init function.

get_flat_params([model_id])

Returns the flattened params for a given model ID or all params if no model ID is specified.

get_lr()

Returns the current learning rate.

get_params([model_id])

Returns the params for a given model ID or all params if no model ID is specified.

init_model(example_input)

Initializes the model parameters using the initializer defined in the constructor.

load(path)

Used to load the parameters from an HDF5 file that can be created using the save function.

process_batch(state, coords, signals)

Used to process the batch before passing it to the loss function.

save(path, **kwargs)

Saves the parameters to an HDF5 file.

verbose_train_model()

Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.

apply_model(coords: Array, model_id: int | None = None) → Array

Applies the model to a given set of coordinates.

Parameters:
  • coords (jnp.ndarray) – The coordinates to apply the model to.

  • model_id (int, optional) – The model ID to apply. Defaults to None, in which case all models are used.

Returns:

The output of the model.

Return type:

jnp.ndarray
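To illustrate the model_id semantics, the sketch below assumes that the parameters of all num_signals models are stacked along a leading axis, so passing None evaluates every model and passing an integer evaluates a single one. apply_stacked and its toy linear "models" are hypothetical; the real trainer applies the network defined by nef_cfg.

```python
import numpy as np


def apply_stacked(weights, coords, model_id=None):
    """Hypothetical: apply a stack of linear models to shared coordinates.

    weights: (num_signals, in_dim, out_dim), one slice per fitted signal.
    coords:  (num_points, in_dim).
    Returns (num_signals, num_points, out_dim) when model_id is None,
    otherwise (num_points, out_dim) for the selected model only.
    """
    if model_id is None:
        # Evaluate every model on the same coordinates in one call
        return np.einsum("pi,sio->spo", coords, weights)
    # Select one model's parameters and evaluate only that model
    return coords @ weights[model_id]
```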

clean_up(clear_caches=True)

Cleans up the trainer by deleting the state and train_step attributes. This is useful to free up memory.

Parameters:

clear_caches (bool, optional) – Whether to clear the JAX caches or not. Defaults to True.

Returns:

None

Return type:

None

compile() → None

Executes the training function once to compile the train_step.

Returns:

None

Return type:

None

create_functions()

Creates the functions needed for training the model. This includes the loss function, the train step, and the train_model function.

This is needed so that the JAX JIT compilation of the train_step function is handled correctly.

For the train_model function, this allows switching between verbose and fast training without any computational overhead during the actual fitting.

create_train_model() → None

Creates the train_model function. This is used to switch between verbose and fast training without any computational overhead during the actual fitting.

Returns:

None

Return type:

None
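The dispatch described for create_train_model can be sketched with a toy class: the choice between the verbose and fast loops is made once, when train_model is created, so the fast loop contains no per-step check of the verbose flag. ToyTrainer, its placeholder train_step, and the exact steps at which it logs are hypothetical and only mirror the method names documented here.

```python
class ToyTrainer:
    """Hypothetical stand-in showing the verbose/fast dispatch pattern."""

    def __init__(self, num_steps: int, verbose: bool = False):
        self.num_steps = num_steps
        self.verbose = verbose
        self.loss = 1.0
        self.log = []
        self.create_train_model()

    def train_step(self):
        self.loss *= 0.99  # placeholder for a real optimization step

    def fast_train_model(self):
        for _ in range(self.num_steps):
            self.train_step()

    def verbose_train_model(self):
        for step in range(self.num_steps):
            self.train_step()
            if step % 100 == 0:  # record the loss every 100 steps
                self.log.append((step, self.loss))

    def create_train_model(self):
        # Bind the loop once; the fast loop never tests self.verbose
        self.train_model = (
            self.verbose_train_model if self.verbose else self.fast_train_model
        )
```

Both loops run the same steps, so they reach the same loss; only the verbose one records progress.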

fast_train_model() → None

Quickly trains the model for the number of steps specified in the init function.

get_flat_params(model_id: int | None = None) → Tuple[Array, List[Tuple[str, List[int]]]]

Returns the flattened params for a given model ID or all params if no model ID is specified.

Parameters:

model_id (int, optional) – The model ID to get the params for or None to get all params.

Returns:

A tuple with the flattened params for the given model ID or all params if no model ID is specified, and the param configuration.

Return type:

Tuple[jnp.ndarray, List[Tuple[str, List[int]]]]
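The param configuration returned alongside the flat vector makes it possible to recover named, shaped parameters. A minimal sketch, assuming the configuration lists (name, shape) pairs in the same order in which the parameters were concatenated, and that flat holds the parameters of a single model; unflatten is a hypothetical helper, not a fit_a_nef function.

```python
import math

import numpy as np


def unflatten(flat, param_config):
    """Hypothetical: rebuild {name: array} from one model's flat params.

    Assumes param_config lists (name, shape) pairs in the order in which
    the parameters were concatenated; verify this against the real output.
    """
    flat = np.asarray(flat)
    if flat.ndim != 1:
        raise ValueError("expected the flat parameters of a single model")
    params, offset = {}, 0
    for name, shape in param_config:
        size = math.prod(shape)  # number of entries in this parameter
        params[name] = flat[offset:offset + size].reshape(shape)
        offset += size
    if offset != flat.size:
        raise ValueError("flat vector length does not match param_config")
    return params
```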

get_lr()

Returns the current learning rate.

Returns:

The current learning rate.

Return type:

float

get_params(model_id: int | None = None) → Array

Returns the params for a given model ID or all params if no model ID is specified.

Parameters:

model_id (int, optional) – The model ID to get the params for or None to get all params.

Returns:

The params for the given model ID or all params if no model ID is specified.

Return type:

jnp.ndarray

init_model(example_input: Array) → None

Initializes the model parameters using the initializer defined in the constructor.

Parameters:

example_input (jnp.ndarray) – An example input to the model. Used by JAX to infer the parameter shapes when initializing the model.

Returns:

None

Return type:

None

load(path: Path) → None

Used to load the parameters from an HDF5 file that can be created using the save function.

Parameters:

path (Path) – The path to load the parameters from.

Returns:

None

Return type:

None

process_batch(state: TrainState, coords: Array, signals: Array) → Tuple[Array, Array, Array]

Used to process the batch before passing it to the loss function. This is useful for selecting specific coordinates or changing the shapes of the signals.

Parameters:
  • state (TrainState) – the current state of the training, passed explicitly because training is implemented in a functional style.

  • coords (jnp.ndarray) – the coordinates to process.

  • signals (jnp.ndarray) – the signals to process.

Returns:

the processed coordinates, signals and the random number generator.

Return type:

Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
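A subclass might override process_batch to fit on a random subset of coordinates at each step. The sketch below assumes coords of shape (num_points, coord_dim) shared by all signals and signals of shape (num_signals, num_points, channels). subsample_batch is hypothetical and uses a NumPy Generator; a JAX implementation would instead split the PRNG key with jax.random.split and return the new key.

```python
import numpy as np


def subsample_batch(rng, coords, signals, num_points):
    """Hypothetical process_batch body: keep a random subset of points.

    coords:  (num_points_total, coord_dim), shared by all signals.
    signals: (num_signals, num_points_total, channels).
    """
    idx = rng.choice(coords.shape[0], size=num_points, replace=False)
    # The same indices select coordinates and every signal, so each
    # sampled coordinate stays paired with its target value.
    return coords[idx], signals[:, idx], rng
```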

save(path: Path, **kwargs) → None

Saves the parameters to an HDF5 file.

Parameters:
  • path (Path) – The path to save the parameters to.

  • kwargs (Dict[str, Any]) – Additional data to save.

Returns:

None

Return type:

None

verbose_train_model() → None

Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.