fit_a_nef.tasks.shape.SignalShapeTrainer
- class fit_a_nef.tasks.shape.SignalShapeTrainer(coords: Array, occupancies: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], log_cfg: Dict[str, Any], initializer: InitModel, num_steps: int, verbose: bool = False, num_points: Tuple[int, int] = (2048, 2048))
Class used to fit neural fields (NeFs) on occupancy signals.
- Parameters:
coords (jnp.ndarray) – The coordinates to train on.
occupancies (jnp.ndarray) – The occupancy values to train on.
train_rng (jnp.ndarray) – The rng to use for training.
nef_cfg (Dict[str, Any]) – The config for the neural network.
scheduler_cfg (Dict[str, Any]) – The config for the scheduler.
optimizer_cfg (Dict[str, Any]) – The config for the optimizer.
log_cfg (Dict[str, Any]) – The config for the logger.
initializer (InitModel) – The initializer for the model.
num_steps (int) – The number of steps to train for.
verbose (bool, optional) – Whether to print progress, defaults to False
num_points (Tuple[int, int], optional) – The number of points to use for each shape, defaults to (2048, 2048)
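The class description, together with the iou() and train_to_target_iou(target_iou, check_every) methods listed below, suggests the usual recipe for fitting a neural field to an occupancy signal. It predicts an occupancy per coordinate, minimizes a classification loss, and measures fit quality with the intersection-over-union (IoU) of the thresholded prediction. The following is a minimal, self-contained JAX sketch of that recipe, not the trainer's implementation. It makes several assumptions: a small ReLU MLP, binary cross-entropy on logits, plain gradient descent with a fixed learning rate, and a 0.5 threshold for IoU. The real trainer builds its network, optimizer, and scheduler from nef_cfg, optimizer_cfg, and scheduler_cfg, whose schemas are not shown here. All helper names are hypothetical, and the standalone train_to_target_iou only mirrors the idea behind the documented method.

```python
import jax
import jax.numpy as jnp

LR = 1e-2  # assumption: fixed-step gradient descent instead of a configured optimizer

def init_mlp(key, sizes=(3, 32, 32, 1)):
    # Hypothetical helper: He-initialized ReLU MLP mapping 3-D coords to one logit.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        {"w": jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), "b": jnp.zeros(n)}
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]

def apply_mlp(params, x):
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return x @ params[-1]["w"] + params[-1]["b"]  # occupancy logits

def bce_loss(params, x, y):
    # Numerically stable binary cross-entropy computed from logits.
    z = apply_mlp(params, x)
    return jnp.mean(jnp.maximum(z, 0) - z * y + jnp.log1p(jnp.exp(-jnp.abs(z))))

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(bce_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads), loss

def iou(params, x, y, threshold=0.5):
    # IoU between predicted occupied points and ground-truth occupied points.
    pred = jax.nn.sigmoid(apply_mlp(params, x)) > threshold
    true = y > 0.5
    return float(jnp.sum(pred & true) / jnp.maximum(jnp.sum(pred | true), 1))

def train_to_target_iou(params, x, y, target_iou, check_every, max_steps):
    # Train until the IoU (checked every check_every steps) reaches target_iou,
    # or until max_steps steps have been taken.
    for step in range(1, max_steps + 1):
        params, _ = train_step(params, x, y)
        if step % check_every == 0:
            score = iou(params, x, y)
            if score >= target_iou:
                return params, score, step
    return params, iou(params, x, y), max_steps

# Toy occupancy signal: points in [-1, 1]^3, occupied inside a sphere of radius 0.6.
data_key, init_key = jax.random.split(jax.random.PRNGKey(0))
coords = jax.random.uniform(data_key, (2048, 3), minval=-1.0, maxval=1.0)
occupancies = (jnp.linalg.norm(coords, axis=-1, keepdims=True) < 0.6).astype(jnp.float32)

params, score, steps = train_to_target_iou(
    init_mlp(init_key), coords, occupancies, target_iou=0.9, check_every=100, max_steps=1000
)
print(f"IoU {score:.3f} after {steps} steps")
```

Checking IoU only every check_every steps keeps the host-side evaluation, which forces a device sync, off the hot path of the training loop.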
- __init__(coords: Array, occupancies: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], log_cfg: Dict[str, Any], initializer: InitModel, num_steps: int, verbose: bool = False, num_points: Tuple[int, int] = (2048, 2048))
Constructor method.
Methods
__init__(coords, occupancies, train_rng, ...) – Constructor method.
apply_model(coords[, model_id]) – Applies the model to a given set of coordinates.
apply_model_all_coords()
clean_up([clear_caches]) – Cleans up the trainer by deleting the state and train_step attributes.
compile() – Executes the training function once to compile the train_step.
create_functions() – Creates the functions needed for training the model.
create_loss()
create_train_model() – Creates the train_model function.
create_train_step() – Quickly trains the model for the number of steps specified in the init function.
get_flat_params([model_id]) – Returns the flattened params for a given model ID, or all params if no model ID is specified.
get_lr() – Returns the current learning rate.
get_params([model_id]) – Returns the params for a given model ID, or all params if no model ID is specified.
init_model(example_input) – Initializes the model parameters using the initializer defined in the constructor.
iou()
load(path) – Loads the parameters from an HDF5 file created with the save function.
process_batch(state, coords, signals) – Processes the batch before it is passed to the loss function.
ram_process_batch()
save(path, **kwargs) – Saves the parameters to an HDF5 file.
train_to_target_iou(target_iou, check_every) – Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
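The optional model_id argument of get_params and get_flat_params suggests that the trainer holds the parameters of several NeFs at once, one per shape. A common JAX layout for this is a single parameter pytree whose leaves carry a leading "model" axis, for example as produced by jax.vmap over an initializer. The sketch below assumes that layout. Its get_params and get_flat_params are hypothetical standalone functions that only mirror the documented method names. jax.flatten_util.ravel_pytree is a real JAX utility that concatenates all leaves of a pytree into one vector.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def init_linear(key, in_dim=3, out_dim=1):
    # Hypothetical single-model parameters: one linear layer.
    return {"w": jax.random.normal(key, (in_dim, out_dim)), "b": jnp.zeros(out_dim)}

num_models = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_models)
stacked_params = jax.vmap(init_linear)(keys)  # leaves: w (4, 3, 1), b (4, 1)

def get_params(params, model_id=None):
    # Select one model's parameters by indexing the leading axis of every leaf.
    if model_id is None:
        return params
    return jax.tree_util.tree_map(lambda leaf: leaf[model_id], params)

def get_flat_params(params, model_id=None):
    if model_id is None:
        # One flat vector per model: shape (num_models, params_per_model).
        return jax.vmap(lambda p: ravel_pytree(p)[0])(params)
    flat, _unravel = ravel_pytree(get_params(params, model_id))
    return flat

print(get_flat_params(stacked_params).shape)      # (num_models, 4)
print(get_flat_params(stacked_params, 2).shape)   # (4,)
```

ravel_pytree also returns an unravel function that maps a flat vector back to the original pytree structure. This is useful when flat parameter vectors have to be turned back into a callable model.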
- compile()
Executes the training function once to compile the train_step.
- Returns:
None
- Return type:
None
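Running a jax.jit-decorated step once is the standard way to pay JAX's trace-and-compile cost up front, so that later calls reuse the cached executable. The self-contained sketch below shows this behavior on a toy least-squares step rather than the trainer's actual train_step. It uses a Python-side counter, which runs only while JAX traces the function, to show that the cached executable is reused as long as input shapes and dtypes stay the same. A new input shape triggers a fresh compilation, which is one reason to keep the number of points per batch fixed.

```python
import time
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces train_step

@jax.jit
def train_step(w, x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    loss, grad = jax.value_and_grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.1 * grad, loss

x, y, w = jnp.ones((2048, 3)), jnp.ones((2048, 1)), jnp.zeros((3, 1))

# "Compile" by executing the step once; block so the timing includes execution.
start = time.perf_counter()
w, loss = train_step(w, x, y)
w.block_until_ready()
print(f"first call (trace + compile + run): {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
w, loss = train_step(w, x, y)  # same shapes and dtypes: cached executable
w.block_until_ready()
print(f"second call (cached): {time.perf_counter() - start:.4f}s")

# A different batch size forces a new trace and compilation.
train_step(w, jnp.ones((1024, 3)), jnp.ones((1024, 1)))
```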
- fast_train_model()
Quickly trains the model for the number of steps specified in the init function.
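The page does not say what makes fast_train_model faster than verbose_train_model. One common JAX technique, assumed here and not confirmed by the source, is to run the whole loop inside compiled code with jax.lax.fori_loop. This avoids Python-level iteration, per-step host synchronization, and logging. The toy least-squares problem, NUM_STEPS, and LR below are illustrative stand-ins for the trainer's model and the num_steps constructor argument.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 200  # stands in for the trainer's num_steps
LR = 0.1

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def fast_train(w, x, y):
    def body(_, w):
        # One gradient-descent step; the loop index is unused.
        return w - LR * jax.grad(loss_fn)(w, x, y)
    # The entire loop compiles into a single XLA program with no per-step logging.
    return jax.lax.fori_loop(0, NUM_STEPS, body, w)

x = jax.random.normal(jax.random.PRNGKey(0), (256, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_w
w = fast_train(jnp.zeros((3, 1)), x, y)
print(w.ravel())  # close to true_w
```

The trade-off is visibility: intermediate losses are not available to the host while the loop runs. A verbose variant keeps the Python loop for exactly that reason.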
- init_model(example_input: Array)
Initializes the model parameters using the initializer defined in the constructor.
- Parameters:
example_input (jnp.ndarray) – An example input to the model. JAX uses it to initialize the model parameters with the correct shapes.
- Returns:
None
- Return type:
None
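In JAX model libraries such as Flax, initialization typically takes a random key and an example input, and only the input's shape is used to infer parameter shapes (as in flax.linen's Module.init). The InitModel initializer used by this trainer is not documented here. The sketch below is therefore a hypothetical, dependency-free version of the pattern. It also assumes one independent parameter set per shape, created with jax.vmap over split keys, which is consistent with the model_id arguments elsewhere in this class.

```python
import jax
import jax.numpy as jnp

def init_params(key, example_input, hidden=32):
    # Only the shape of example_input matters: its last axis sets the input width.
    in_dim = example_input.shape[-1]
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(1),
    }

num_shapes = 4
example_input = jnp.zeros((2048, 3))  # e.g. one batch of 3-D coordinates
keys = jax.random.split(jax.random.PRNGKey(0), num_shapes)
# One independently initialized parameter set per shape, stacked on a leading axis;
# in_axes=(0, None) maps over keys while sharing the same example input.
params = jax.vmap(init_params, in_axes=(0, None))(keys, example_input)
print(jax.tree_util.tree_map(jnp.shape, params))
```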
- verbose_train_model()
Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
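A verbose training loop of the kind described above is usually a Python loop around a jitted step that reads the loss back to the host only on logging steps. The sketch below assumes that structure on a toy least-squares problem. NUM_STEPS stands in for num_steps, and the 100-step interval comes from the description. The logged list exists only so the result can be inspected.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 300  # stands in for the trainer's num_steps
LOG_EVERY = 100
LR = 0.01

@jax.jit
def train_step(w, x, y):
    loss, grad = jax.value_and_grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - LR * grad, loss

x = jax.random.normal(jax.random.PRNGKey(0), (256, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]])
w = jnp.zeros((3, 1))

logged = []
for step in range(1, NUM_STEPS + 1):
    w, loss = train_step(w, x, y)
    if step % LOG_EVERY == 0:
        # float() blocks until the device finishes, so only do it on logging steps.
        value = float(loss)
        logged.append((step, value))
        print(f"step {step}: loss {value:.6f}")
```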