fit_a_nef.utils

Functions

flatten_dict(d[, separation])

Flattens a nested dictionary into a single-level dictionary, joining nested keys with the separation string.
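The following is a minimal pure-Python sketch of dictionary flattening. It assumes that nested keys are joined with a separator string. The default of "." used here is an assumption and is not necessarily the library's default for separation.

```python
def flatten_dict(d, separation="."):
    """Flatten nested dicts: {"a": {"b": 1}} -> {"a.b": 1} (illustrative sketch)."""
    flat = {}
    for key, value in d.items():
        if isinstance(value, dict):
            # Recurse, then prefix every nested key with the parent key.
            for sub_key, sub_value in flatten_dict(value, separation).items():
                flat[f"{key}{separation}{sub_key}"] = sub_value
        else:
            # Non-dict values (e.g. parameter arrays) are kept as leaves.
            flat[key] = value
    return flat


params = {"layer_0": {"kernel": [1, 2], "bias": [0]}, "scale": 3}
print(flatten_dict(params))
```

Non-dict values are treated as leaves. In this sketch, an empty nested dictionary therefore produces no entries at all.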

flatten_params(params[, num_batch_dims, ...])

Flattens the parameters of the model.
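The sketch below shows one way to flatten parameters, using NumPy instead of JAX. Several details are assumptions rather than the library's documented behaviour:

- params is taken to be a flat dict of arrays.
- The first num_batch_dims axes of every array are shared batch axes (e.g. several fields fitted in parallel) and are kept as-is.
- All remaining axes are flattened and concatenated along the last axis.
- The (name, shape) list returned alongside the array is a hypothetical record that allows the split to be undone.

```python
import numpy as np


def flatten_params(params, num_batch_dims=0):
    """Concatenate all parameter arrays into one array (illustrative sketch)."""
    param_config = []  # hypothetical (name, per-example shape) record
    chunks = []
    for name in sorted(params):  # fixed order so the split can be reversed
        leaf = np.asarray(params[name])
        batch_shape = leaf.shape[:num_batch_dims]
        param_config.append((name, leaf.shape[num_batch_dims:]))
        # Keep the batch axes, flatten everything after them into one axis.
        chunks.append(leaf.reshape(batch_shape + (-1,)))
    return np.concatenate(chunks, axis=-1), param_config


params = {"kernel": np.ones((2, 3, 4)), "bias": np.zeros((2, 4))}
flat, config = flatten_params(params, num_batch_dims=1)
print(flat.shape, config)
```

Sorting the names gives a deterministic layout, which an inverse operation needs. For a real JAX pytree, jax.flatten_util.ravel_pytree offers a related (unbatched) ravel/unravel pair.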

get_meta_init(storage_folder, epoch_index)

Loads the meta-learned initialization for the current configuration.

get_nef(nef_cfg)

Returns the model for the given config.

get_optimizer(optimizer_cfg, scheduler)

Returns the optimizer for the given config and scheduler.

get_scheduler(scheduler_cfg)

Returns the scheduler for the given config.
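Here is a sketch of the config-to-schedule dispatch that a function like get_scheduler performs. It is written in pure Python. The config keys ("name", "learning_rate", "decay_rate", "transition_steps") and the scheduler names are hypothetical, chosen only for illustration. The returned object is a callable mapping a step count to a learning rate, which is the same shape as an Optax schedule.

```python
def get_scheduler(scheduler_cfg):
    """Map a (hypothetical) config dict to a schedule: step -> learning rate."""
    name = scheduler_cfg["name"]
    lr = scheduler_cfg["learning_rate"]
    if name == "constant":
        return lambda step: lr
    if name == "exponential_decay":
        rate = scheduler_cfg["decay_rate"]
        steps = scheduler_cfg["transition_steps"]
        # Smooth decay: the rate is multiplied by decay_rate every `steps` steps.
        return lambda step: lr * rate ** (step / steps)
    raise ValueError(f"Unknown scheduler: {name!r}")


schedule = get_scheduler({"name": "exponential_decay", "learning_rate": 1e-3,
                          "decay_rate": 0.5, "transition_steps": 100})
print(schedule(0), schedule(100))
```

Returning a plain callable keeps the schedule independent of the optimizer. It can then be passed as a learning rate when the optimizer is built.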

unflatten_dict(d[, separation])

Unflattens a dictionary; the inverse of flatten_dict.
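A minimal pure-Python sketch of the inverse operation follows. It assumes keys were joined with a separator string, and the "." default is again an assumption. Each key is split on the separator, and the intermediate dictionaries are created on demand.

```python
def unflatten_dict(d, separation="."):
    """Rebuild nested dicts from joined keys: {"a.b": 1} -> {"a": {"b": 1}} (sketch)."""
    nested = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(separation)
        node = nested
        for part in parents:
            # Create each intermediate dict on first use, then descend into it.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


flat = {"layer_0.kernel": [1, 2], "layer_0.bias": [0], "scale": 3}
print(unflatten_dict(flat))
```

The round trip is exact only when the original keys are strings that do not themselves contain the separator.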

unflatten_params(param_config, comb_params)

Unflattens the parameters of the model; the inverse of flatten_params.
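The NumPy sketch below splits a combined parameter array back into named arrays. It rests on these assumptions, none of which is confirmed by this listing:

- param_config is a list of (name, shape) pairs, in the order the parameters were concatenated along the last axis.
- comb_params holds that concatenation.
- Any leading axes of comb_params are batch axes.

```python
import numpy as np


def unflatten_params(param_config, comb_params):
    """Split a combined parameter array into named arrays (illustrative sketch)."""
    comb_params = np.asarray(comb_params)
    batch_shape = comb_params.shape[:-1]  # leading axes are kept as batch axes
    params, offset = {}, 0
    for name, shape in param_config:
        size = int(np.prod(shape))  # np.prod(()) is 1.0, so scalars take one slot
        chunk = comb_params[..., offset:offset + size]
        params[name] = chunk.reshape(batch_shape + tuple(shape))
        offset += size
    if offset != comb_params.shape[-1]:
        raise ValueError("param_config does not match the size of comb_params")
    return params


config = [("bias", (4,)), ("kernel", (3, 4))]
comb = np.arange(2 * 16, dtype=np.float32).reshape(2, 16)
params = unflatten_params(config, comb)
print({k: v.shape for k, v in params.items()})
```

The final size check catches a param_config that does not match comb_params. Without it, leftover entries would be silently ignored.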

Classes

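TrainState(step, apply_fn, params, tx, opt_state)

Training state holding the step counter, the model's apply function, its parameters, the optimizer (tx) and the optimizer state.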
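The field names match those of Flax's flax.training.train_state.TrainState. This listing does not say whether fit_a_nef's class subclasses it or only mirrors it.

The pure-Python sketch below shows how such an immutable state is typically advanced. It makes the following assumptions:

- The create and apply_gradients methods follow Flax's naming.
- The optimizer interface (init and update returning updates plus a new state) is modelled on Optax.
- The SGD class is a hypothetical toy stand-in for a real optimizer.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class SGD:
    """Toy stand-in for an optimizer `tx` (hypothetical, not Optax)."""
    learning_rate: float

    def init(self, params):
        return ()  # plain SGD keeps no optimizer state

    def update(self, grads, opt_state, params):
        updates = {k: -self.learning_rate * g for k, g in grads.items()}
        return updates, opt_state


@dataclass(frozen=True)
class TrainState:
    step: int
    apply_fn: Callable[..., Any]
    params: Dict[str, float]
    tx: Any
    opt_state: Any

    @classmethod
    def create(cls, apply_fn, params, tx):
        return cls(step=0, apply_fn=apply_fn, params=params, tx=tx,
                   opt_state=tx.init(params))

    def apply_gradients(self, grads):
        # Immutable update: build a new state rather than mutating this one.
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = {k: p + updates[k] for k, p in self.params.items()}
        return replace(self, step=self.step + 1, params=new_params,
                       opt_state=new_opt_state)


state = TrainState.create(apply_fn=lambda p, x: p["w"] * x,
                          params={"w": 1.0}, tx=SGD(learning_rate=0.1))
state = state.apply_gradients({"w": 2.0})
print(state.step, state.params["w"], state.apply_fn(state.params, 3.0))
```

Keeping the state immutable is the convention in JAX code: under jit, a training step takes a state and returns a new one instead of updating it in place.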