fit_a_nef.utils.unflatten_params

fit_a_nef.utils.unflatten_params(param_config: List[Tuple[str, List[int]]], comb_params: Array)

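Reconstructs the model's parameter PyTree from a single flattened parameter array, using param_config to recover the layout of each parameter.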

Parameters:
  • param_config (List[Tuple[str, List[int]]]) – Structure of the flattened parameters, given as a list of (name, shape) tuples, one per parameter.

  • comb_params (jnp.ndarray) – The flattened parameters, combined into a single array.

Returns:

The model parameters, restored to their PyTree structure.

Return type:

jax.PyTree
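The following is a minimal NumPy sketch of how such an unflattening step might work. It is not the library's implementation. It assumes that each (name, shape) entry in param_config is listed in the order the parameters were flattened, that each parameter occupies a contiguous slice of comb_params, and that names are "/"-separated paths into a nested dict. The function name unflatten_params_sketch and the "Dense_0/..." parameter names are hypothetical. The same slicing and reshaping calls are available in jax.numpy.

```python
import math
from typing import Any, Dict, List, Tuple

import numpy as np


def unflatten_params_sketch(
    param_config: List[Tuple[str, List[int]]], comb_params: np.ndarray
) -> Dict[str, Any]:
    """Hypothetical sketch: rebuild a nested dict of arrays from a 1-D vector.

    Assumes entries are in flattening order and names are "/"-separated paths.
    """
    params: Dict[str, Any] = {}
    offset = 0
    for name, shape in param_config:
        size = math.prod(shape)  # number of scalars in this parameter
        leaf = np.reshape(comb_params[offset : offset + size], shape)
        offset += size
        # Create intermediate dicts along the path, then store the leaf.
        *parents, leaf_name = name.split("/")
        node = params
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf_name] = leaf
    if offset != comb_params.shape[0]:
        raise ValueError(
            f"param_config covers {offset} values, but comb_params has {comb_params.shape[0]}"
        )
    return params


# Usage with a hypothetical two-parameter layer (2x3 kernel, 3-element bias).
config = [("Dense_0/kernel", [2, 3]), ("Dense_0/bias", [3])]
flat = np.arange(9.0)
tree = unflatten_params_sketch(config, flat)
print(tree["Dense_0"]["kernel"].shape)  # (2, 3)
print(tree["Dense_0"]["bias"])          # [6. 7. 8.]
```

The final length check is a design choice of this sketch: it surfaces a mismatch between param_config and comb_params instead of silently ignoring leftover values.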