fit_a_nef.tasks.image.SignalImageTrainer
- class fit_a_nef.tasks.image.SignalImageTrainer(signals: Array, coords: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, log_cfg: Dict[str, Any] | None = None, num_steps: int = 500, verbose: bool = False, masked_portion: float = 0.5, images_shape: Tuple[int, int, int] | None = None, images_mean: Array | None = None, images_std: Array | None = None)
Fit a set of neural fields to a set of images, given a certain initialization method.
- Parameters:
signals (jnp.ndarray) – The images to fit to.
coords (jnp.ndarray) – The coordinates of the images.
train_rng (jnp.ndarray) – The random number generator to use.
nef_cfg (Dict[str, Any]) – The config for the neural fields.
scheduler_cfg (Dict[str, Any]) – The config for the scheduler.
optimizer_cfg (Dict[str, Any]) – The config for the optimizer.
initializer (InitModel) – The initializer to use.
log_cfg (Optional[Dict[str, Any]], optional) – The config for the logger. Defaults to None.
num_steps (int, optional) – The number of steps to train for. Defaults to 500.
verbose (bool, optional) – Whether to log the training. Defaults to False.
masked_portion (float, optional) – The portion of the image to mask. Defaults to 0.5.
images_shape (Optional[Tuple[int, int, int]], optional) – The shape of the images. Defaults to None.
images_mean (Optional[jnp.ndarray], optional) – The mean of the images. Defaults to None.
images_std (Optional[jnp.ndarray], optional) – The std of the images. Defaults to None.
- __init__(signals: Array, coords: Array, train_rng: Array, nef_cfg: Dict[str, Any], scheduler_cfg: Dict[str, Any], optimizer_cfg: Dict[str, Any], initializer: InitModel, log_cfg: Dict[str, Any] | None = None, num_steps: int = 500, verbose: bool = False, masked_portion: float = 0.5, images_shape: Tuple[int, int, int] | None = None, images_mean: Array | None = None, images_std: Array | None = None)
Constructor method.
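The parameter list above does not state the array layouts the trainer expects. As a minimal sketch, assuming coords is a flat list of 2D grid positions and signals holds one flattened image per neural field, the inputs could be prepared as follows. The helper names make_image_coords and flatten_images, the [-1, 1] coordinate range, and the (num_images, H * W, C) layout are all assumptions for illustration, not part of the documented API.

```python
import jax.numpy as jnp


def make_image_coords(height: int, width: int) -> jnp.ndarray:
    """Hypothetical helper: regular (height * width, 2) grid of coordinates in [-1, 1]."""
    ys = jnp.linspace(-1.0, 1.0, height)
    xs = jnp.linspace(-1.0, 1.0, width)
    # indexing="ij" keeps the row (y) axis first, matching image memory layout.
    grid_y, grid_x = jnp.meshgrid(ys, xs, indexing="ij")
    return jnp.stack([grid_y, grid_x], axis=-1).reshape(-1, 2)


def flatten_images(images: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: reshape (num_images, H, W, C) to (num_images, H * W, C)."""
    num_images, height, width, channels = images.shape
    return images.reshape(num_images, height * width, channels)
```

With this layout, row i of coords corresponds to pixel i of every flattened image, so one set of coordinates can be shared by all signals.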
Methods
__init__(signals, coords, train_rng, ...[, ...]) – Constructor method.
apply_model(coords[, model_id]) – Applies the model to a given set of coordinates.
clean_up([clear_caches]) – Cleans up the trainer by deleting the state and train_step attributes.
compile() – Executes the training function once to compile the train_step.
create_functions() – Creates the functions needed for training the model.
create_loss()
create_train_model() – Creates the train_model function.
create_train_step()
fast_train_model() – Quickly trains the model for the number of steps specified in the init function.
get_flat_params([model_id]) – Returns the flattened params for a given model ID or all params if no model ID is specified.
get_lr() – Returns the current learning rate.
get_params([model_id]) – Returns the params for a given model ID or all params if no model ID is specified.
init_model(example_input) – Initializes the model parameters using the initializer defined in the constructor.
load(path) – Loads the parameters from an HDF5 file that can be created using the save function.
mae() – Calculate the mean absolute error (MAE) between the reconstructed signals and the original signals.
mse() – Calculates the mean squared error (MSE) between the reconstructed signals and the original signals.
process_batch(state, coords, images) – Used to process the batch before passing it to the loss function.
psnr() – Calculate the Peak Signal-to-Noise Ratio (PSNR) between the reconstructed images and the original images.
save(path, **kwargs) – Save the parameters to an HDF5 file.
simse() – Calculate the scale-invariant mean squared error (SIMSE) between the reconstructed image and the original signal.
ssim() – Calculates the Structural Similarity Index (SSIM) between the reconstructed images and the original signals.
train_to_target_psnr(target_psnr, ...)
validation_psnr() – Calculate the Peak Signal-to-Noise Ratio (PSNR) for the validation images.
verbose_train_model() – Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.
- clean_up(clear_caches=True)
Cleans up the trainer by deleting the state and train_step attributes. This is useful to free up memory.
- Parameters:
clear_caches (bool, optional) – Whether to clear the Jax caches or not. Defaults to True.
- Returns:
None
- Return type:
None
- init_model(example_input: Array)
Initializes the model parameters using the initializer defined in the constructor.
- Parameters:
example_input (jnp.ndarray) – An example input to the model. Used by Jax to initialize the model correctly.
- Returns:
None
- Return type:
None
- mae()
Calculate the mean absolute error (MAE) between the reconstructed signals and the original signals.
- Returns:
Tuple[float, float]: A tuple containing the mean of the MAE metric and the mean squared MAE metric.
- mse()
Calculates the mean squared error (MSE) between the reconstructed signals and the original signals.
- Returns:
Tuple[float, float]: A tuple containing the mean MSE and the mean squared MSE.
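The phrase "mean squared MSE" is ambiguous. One plausible reading, which is an assumption here and not confirmed by the documentation, is that the metric is computed per signal and the method returns both the mean of those values and the mean of their squares, so that a variance can be derived afterwards. A minimal sketch of that reading, using the hypothetical name mse_summary:

```python
import jax.numpy as jnp


def mse_summary(recon: jnp.ndarray, target: jnp.ndarray) -> tuple[float, float]:
    """Hypothetical sketch: (mean of per-signal MSE, mean of squared per-signal MSE)."""
    # Reduce over every axis except the leading signal axis.
    reduce_axes = tuple(range(1, recon.ndim))
    per_signal = jnp.mean((recon - target) ** 2, axis=reduce_axes)
    return float(jnp.mean(per_signal)), float(jnp.mean(per_signal ** 2))
```

Under this reading, the variance of the metric across signals is the second value minus the square of the first.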
- process_batch(state, coords, images)
Used to process the batch before passing it to the loss function. This is useful for selecting specific coordinates or changing the shapes of the signals.
- Parameters:
state (TrainState) – the current state of the training, used for functional programming.
coords (jnp.ndarray) – the coordinates to process.
images (jnp.ndarray) – the images (signals) to process.
- Returns:
the processed coordinates, signals and the random number generator.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
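To illustrate what "selecting specific coordinates" could look like, the sketch below drops a random fraction of the points and returns the kept coordinates, the matching image values, and an updated random key. It is not the library's implementation: the function name mask_batch, the array shapes, and the interpretation of masked_portion as the fraction of points removed are all assumptions.

```python
import jax
import jax.numpy as jnp


def mask_batch(rng, coords, images, masked_portion=0.5):
    """Hypothetical sketch of random coordinate masking.

    coords: (num_points, dim); images: (num_images, num_points, channels).
    """
    # Split so the caller receives a fresh key for the next step.
    rng, perm_rng = jax.random.split(rng)
    num_points = coords.shape[0]
    num_kept = int(round(num_points * (1.0 - masked_portion)))
    # A random permutation of point indices; keep only the first num_kept.
    kept = jax.random.permutation(perm_rng, num_points)[:num_kept]
    return coords[kept], images[:, kept], rng
```

Returning the new key alongside the data keeps the function pure, which fits the functional style the state argument hints at.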
- psnr()
Calculate the Peak Signal-to-Noise Ratio (PSNR) between the reconstructed images and the original images.
- Returns:
Tuple[float, float]: A tuple containing the mean PSNR and the mean squared PSNR.
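For reference, PSNR is conventionally defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value. The sketch below computes it per image and includes a step that undoes normalization with a dataset mean and std. How the trainer itself uses images_mean and images_std is not documented here, so denormalize, psnr_per_image, and the default max_value of 1.0 are assumptions for illustration.

```python
import jax.numpy as jnp


def denormalize(images, mean, std):
    """Hypothetical helper: map normalized values back to the original pixel range."""
    return images * std + mean


def psnr_per_image(recon, target, max_value=1.0):
    """PSNR in dB for each image; inputs have shape (num_images, ...)."""
    reduce_axes = tuple(range(1, recon.ndim))
    mse = jnp.mean((recon - target) ** 2, axis=reduce_axes)
    return 10.0 * jnp.log10(max_value ** 2 / mse)
```

A uniform error of 0.1 on images in the [0, 1] range gives an MSE of 0.01 and therefore a PSNR of about 20 dB.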
- simse()
Calculate the scale-invariant mean squared error (SIMSE) between the reconstructed image and the original signal.
- Returns:
Tuple[float, float]: A tuple containing the mean SIMSE and the mean squared SIMSE.
- ssim()
Calculates the Structural Similarity Index (SSIM) between the reconstructed images and the original signals.
- Returns:
Tuple[float, float]: A tuple containing the mean SSIM and the mean squared SSIM.
- train_to_target_psnr(target_psnr, check_every, mean, std)
Trains the model until the target PSNR is reached, checking the PSNR at the interval given by check_every.
- Parameters:
target_psnr (int) – The target PSNR to reach.
check_every (int) – How often to check the PSNR.
mean (float) – The mean of the dataset. Used in the PSNR calculation.
std (float) – The std of the dataset. Used in the PSNR calculation.
- Returns:
The number of steps it took to reach the target PSNR.
- Return type:
int
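The control flow implied by these arguments can be sketched as a loop that trains, checks the PSNR at a fixed interval, and stops once the target is met. This is only an illustration under stated assumptions: train_until_psnr, the train_step and current_psnr callables, and the max_steps safety cap are hypothetical and do not come from the library.

```python
from typing import Callable


def train_until_psnr(
    train_step: Callable[[], None],
    current_psnr: Callable[[], float],
    target_psnr: float,
    check_every: int,
    max_steps: int = 10_000,  # hypothetical cap so the loop always terminates
) -> int:
    """Run train_step until current_psnr() reaches target_psnr; return steps taken."""
    num_steps = 0
    while num_steps < max_steps:
        train_step()
        num_steps += 1
        # Evaluating PSNR is comparatively expensive, so only do it periodically.
        if num_steps % check_every == 0 and current_psnr() >= target_psnr:
            break
    return num_steps
```

Because the PSNR is only checked every check_every steps, the returned step count is a multiple of check_every (unless the cap is hit) and can overshoot the first step at which the target was actually reached.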
- validation_psnr()
Calculate the Peak Signal-to-Noise Ratio (PSNR) for the validation images.
- Returns:
Tuple[float, float]: The mean PSNR and the mean squared PSNR.
- verbose_train_model()
Trains the model for the number of steps specified in the init function and logs the loss every 100 steps.