fit_a_nef.utils.flatten_params
- fit_a_nef.utils.flatten_params(params: Any, num_batch_dims: int = 0, param_key: callable | None = None) → Tuple[List[Tuple[str, List[int]]], Array]
Flattens the parameters of the model into a single array.
- Parameters:
params (jax.PyTree) – The parameters of the model.
num_batch_dims (int, optional) – The number of batch dimensions. Tensors are not flattened over these dimensions. Defaults to 0.
param_key (callable, optional) – Key function used to sort the parameters. Defaults to None.
- Returns:
A tuple containing the structure of the flattened parameters and the flattened parameters themselves.
- Return type:
Tuple[List[Tuple[str, List[int]]], Array]
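To show what this kind of flattening typically involves, here is a minimal sketch under stated assumptions. It is not the fit_a_nef implementation. `flatten_params_sketch` is a hypothetical name. It assumes that `param_key` is a sort key applied to `(name, leaf)` pairs, that parameter names come from `jax.tree_util.keystr`, and that the batch dimensions are the leading ones. Every leaf keeps its batch dimensions, and the remaining axes are collapsed and concatenated along the last axis. The recorded per-example shapes are what would let a caller unflatten the vector later.

```python
from typing import Any, Callable, List, Optional, Tuple

import jax
import jax.numpy as jnp


def flatten_params_sketch(
    params: Any,
    num_batch_dims: int = 0,
    param_key: Optional[Callable[[Tuple[str, jax.Array]], Any]] = None,
) -> Tuple[List[Tuple[str, List[int]]], jax.Array]:
    """Hypothetical sketch of parameter flattening; not the fit_a_nef source."""
    # Pair every leaf with a path string, e.g. "['dense']['kernel']".
    leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(params)
    named = [
        (jax.tree_util.keystr(path), jnp.asarray(leaf))
        for path, leaf in leaves_with_path
    ]
    # Assumption: param_key orders the (name, leaf) pairs, like sorted(key=...).
    if param_key is not None:
        named = sorted(named, key=param_key)

    structure: List[Tuple[str, List[int]]] = []
    chunks = []
    for name, leaf in named:
        # Assumption: the batch dimensions are the leading ones.
        batch_shape = leaf.shape[:num_batch_dims]
        # Record the per-example shape so the vector can be unflattened later.
        structure.append((name, list(leaf.shape[num_batch_dims:])))
        # Keep the batch dimensions and collapse everything else into one axis.
        chunks.append(leaf.reshape(batch_shape + (-1,)))
    return structure, jnp.concatenate(chunks, axis=-1)


# Usage: a single model, then a batch of 4 models stacked along axis 0.
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}
structure, flat = flatten_params_sketch(params)

batched = jax.tree_util.tree_map(lambda x: jnp.stack([x] * 4), params)
b_structure, b_flat = flatten_params_sketch(batched, num_batch_dims=1)
```

With `num_batch_dims=1`, each of the 4 stacked models gets its own flat row, while the recorded shapes stay per-model.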