fit_a_nef.tasks.shape.SignalShapeTrainer

class fit_a_nef.tasks.shape.SignalShapeTrainer(coords: Array, occupancies: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], log_cfg: Dict[str, Any], initializer: InitModel, num_steps: int, verbose: bool = False, num_points: Tuple[int, int] = (2048, 2048))

Class used to fit neural fields (NeFs) on occupancy signals.

Parameters:
  • coords (jnp.ndarray) – The coordinates to train on.

  • occupancies (jnp.ndarray) – The occupancy values to train on.

  • train_rng (jnp.ndarray) – The JAX PRNG key to use for training.

  • nef_cfg (Dict[str, Any]) – The config for the neural network.

  • scheduler_cfg (Dict[str, Any]) – The config for the scheduler.

  • optimizer_cfg (Dict[str, Any]) – The config for the optimizer.

  • log_cfg (Dict[str, Any]) – The config for the logger.

  • initializer (InitModel) – The initializer for the model.

  • num_steps (int) – The number of steps to train for.

  • verbose (bool, optional) – Whether to print progress, defaults to False.

  • num_points (Tuple[int, int], optional) – The number of points to use for each shape, defaults to (2048, 2048).
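For experimenting with the trainer, a toy occupancy dataset can be built from analytic shapes. The sketch below is illustrative only: it assumes coords is laid out as (num_shapes, num_points, 3) and occupancies as (num_shapes, num_points, 1), a layout this page does not confirm, and the helper make_sphere_batch is hypothetical. NumPy arrays can be converted with jnp.asarray before being passed to the constructor.

```python
import numpy as np


def make_sphere_batch(radii, num_points, seed=0):
    """Hypothetical toy data: one sphere per radius, points sampled uniformly in [-1, 1]^3."""
    rng = np.random.default_rng(seed)
    num_shapes = len(radii)
    coords = rng.uniform(-1.0, 1.0, size=(num_shapes, num_points, 3)).astype(np.float32)
    r = np.asarray(radii, dtype=np.float32)[:, None, None]  # (num_shapes, 1, 1)
    dist = np.linalg.norm(coords, axis=-1, keepdims=True)   # (num_shapes, num_points, 1)
    occupancies = (dist < r).astype(np.float32)             # 1 inside, 0 outside
    return coords, occupancies


coords, occupancies = make_sphere_batch([0.3, 0.5, 0.7], num_points=4096)
```

Larger radii give a larger fraction of occupied points, which makes the batch easy to sanity-check before training.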

__init__(coords: Array, occupancies: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], log_cfg: Dict[str, Any], initializer: InitModel, num_steps: int, verbose: bool = False, num_points: Tuple[int, int] = (2048, 2048))

Constructor method.

Methods

__init__(coords, occupancies, train_rng, ...)

Constructor method.

apply_model(coords[, model_id])

Applies the model to a given set of coordinates.

apply_model_all_coords()

clean_up([clear_caches])

Cleans up the trainer by deleting the state and train_step attributes.

compile()

Executes the training function once to compile the train_step.

create_functions()

Creates the functions needed for training the model.

create_loss()

create_train_model()

Creates the train_model function.

create_train_step()

fast_train_model()

Quickly trains the model for the number of steps specified in the init function.

get_flat_params([model_id])

Returns the flattened params for a given model ID or all params if no model ID is specified.

get_lr()

Returns the current learning rate.

get_params([model_id])

Returns the params for a given model ID or all params if no model ID is specified.

init_model(example_input)

Initializes the model parameters using the initializer defined in the constructor.

iou()

load(path)

Used to load the parameters from an HDF5 file that can be created using the save function.

process_batch(state, coords, signals)

Used to process the batch before passing it to the loss function.

ram_process_batch()

save(path, **kwargs)

Saves the parameters to an HDF5 file.

train_to_target_iou(target_iou, check_every)

verbose_train_model()

Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
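The iou() and train_to_target_iou(target_iou, check_every) entries above have no description. The sketch below shows one plausible meaning of both, stated as assumptions: IoU is computed per shape between thresholded predictions and binary ground-truth occupancies (the model is assumed to output logits, so threshold 0 means probability 0.5), and training runs in chunks of check_every steps until the mean IoU reaches the target. occupancy_iou and train_until_iou are hypothetical helpers with their own signatures, not the trainer's methods.

```python
import numpy as np


def occupancy_iou(pred_logits, occupancies, threshold=0.0):
    """Per-shape IoU between thresholded predictions and binary ground truth.

    Assumes both arrays are shaped (num_shapes, num_points, 1).
    """
    pred = np.asarray(pred_logits) > threshold
    gt = np.asarray(occupancies) > 0.5
    inter = np.logical_and(pred, gt).sum(axis=(-2, -1))
    union = np.logical_or(pred, gt).sum(axis=(-2, -1))
    # Two empty occupancy sets agree perfectly, so define IoU = 1 when union is 0.
    return np.where(union > 0, inter / np.maximum(union, 1), 1.0)


def train_until_iou(step_fn, eval_fn, target_iou, check_every, max_steps=10_000):
    """Call step_fn in chunks of check_every steps until the mean IoU >= target_iou.

    Returns the number of steps taken (capped at max_steps).
    """
    steps = 0
    while steps < max_steps:
        for _ in range(check_every):
            step_fn()
        steps += check_every
        if np.mean(eval_fn()) >= target_iou:
            break
    return steps


# Usage with a fake "training" process whose IoU rises by 0.001 per step.
counter = {"n": 0}


def fake_step():
    counter["n"] += 1


def fake_eval():
    return np.array([counter["n"] / 1000])


steps_taken = train_until_iou(fake_step, fake_eval, target_iou=0.9, check_every=100)
# steps_taken == 900
```

Checking only every check_every steps keeps the evaluation cost low, at the price of possibly overshooting the target by up to check_every - 1 steps.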

compile()

Executes the training function once to compile the train_step.

Returns:

None

Return type:

None
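JAX compiles a jitted function the first time it is called with new input shapes, so running one step up front moves that cost out of the training loop. The sketch below shows this general warm-up pattern with a toy least-squares step; train_step here is a stand-in rather than fit_a_nef's own train_step, and discarding the warm-up result (so the training state is unchanged) is an assumption about how compile() behaves.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, x, y):
    """One gradient-descent step on a toy least-squares loss."""
    def loss_fn(p):
        return jnp.mean((x @ p - y) ** 2)
    loss, grad = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grad, loss


x = jnp.ones((8, 3))
y = jnp.ones((8,))
params = jnp.zeros((3,))

# Warm-up: the first call traces and compiles train_step for these shapes/dtypes.
# The updated params are discarded so the warm-up does not change the state.
t0 = time.perf_counter()
_, warm_loss = train_step(params, x, y)
warm_loss.block_until_ready()
first_call_s = time.perf_counter() - t0

# Later calls with the same shapes reuse the cached executable.
t0 = time.perf_counter()
for _ in range(10):
    params, loss = train_step(params, x, y)
loss.block_until_ready()
ten_calls_s = time.perf_counter() - t0
```

On most setups first_call_s is dominated by compilation, while ten_calls_s reflects only execution.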

fast_train_model()

Quickly trains the model for the number of steps specified in the init function.
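A common way to make a JAX training loop fast is to run every step inside a single compiled jax.lax.fori_loop instead of a Python for loop. Whether fast_train_model is implemented this way is not stated here; the sketch only illustrates the technique, using a toy loss and a hypothetical fast_train helper.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Toy least-squares loss standing in for the occupancy loss."""
    return jnp.mean((x @ params - y) ** 2)


@partial(jax.jit, static_argnames="num_steps")
def fast_train(params, x, y, num_steps, lr=0.1):
    def body(_, p):
        return p - lr * jax.grad(loss_fn)(p, x, y)
    # All num_steps updates run inside one compiled loop, so there is no
    # per-step Python dispatch and no host synchronization for logging.
    return jax.lax.fori_loop(0, num_steps, body, params)


x = jnp.ones((8, 3))
y = jnp.ones((8,))
trained = fast_train(jnp.zeros((3,)), x, y, num_steps=20)
```

Marking num_steps as static means a new compilation is triggered whenever it changes, which is acceptable when the step count is fixed in the constructor.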

init_model(example_input: Array)

Initializes the model parameters using the initializer defined in the constructor.

Parameters:

example_input (jnp.ndarray) – An example input to the model. Used by JAX to initialize the model correctly.

Returns:

None

Return type:

None
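The example input matters because the shapes of the first layer's parameters depend on the input's last dimension; Flax-style init(rng, example_input) calls infer them this way. The pure-JAX sketch below mimics that idea for a small ReLU MLP. init_mlp and apply_mlp are hypothetical helpers and do not use the trainer's InitModel initializer.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, example_input, hidden=(32, 32), out_dim=1):
    """Initialize MLP weights; the input width is read from example_input."""
    sizes = (example_input.shape[-1], *hidden, out_dim)
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, fan_in, fan_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-normal weights suit the ReLU hidden layers; biases start at zero.
        w = jax.random.normal(k, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)
        params.append({"w": w, "b": jnp.zeros((fan_out,))})
    return params


def apply_mlp(params, x):
    """ReLU MLP forward pass; the last layer is linear (e.g. occupancy logits)."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return x @ params[-1]["w"] + params[-1]["b"]


example_input = jnp.zeros((1, 3))  # one 3D coordinate; only its shape is used
params = init_mlp(jax.random.PRNGKey(0), example_input)
logits = apply_mlp(params, jnp.ones((5, 3)))
```

Only the shape of example_input is read, so a zero-filled array is enough.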

verbose_train_model()

Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
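A verbose loop typically stays in Python so that it can report progress between compiled steps. The sketch below is a hypothetical verbose_train helper with a toy loss; it logs every 100 steps to match the description above, but it is not the trainer's own code.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Toy least-squares loss standing in for the occupancy loss."""
    return jnp.mean((x @ params - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grad = jax.value_and_grad(loss_fn)(params, x, y)
    return params - 0.005 * grad, loss


def verbose_train(params, x, y, num_steps, log_every=100):
    """Python-level loop that prints and records the loss every log_every steps."""
    history = []
    for step in range(num_steps):
        params, loss = train_step(params, x, y)
        if step % log_every == 0:
            # float() copies the loss to the host and waits for the device,
            # so logging only occasionally keeps the overhead low.
            value = float(loss)
            history.append((step, value))
            print(f"step {step:>5d}  loss {value:.6f}")
    return params, history


x = jnp.ones((8, 3))
y = jnp.ones((8,))
params, history = verbose_train(jnp.zeros((3,)), x, y, num_steps=300)
```

The loss reported at each logged step is the value computed before that step's update.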