fit_a_nef.utils.flatten_params

fit_a_nef.utils.flatten_params(params: Any, num_batch_dims: int = 0, param_key: callable | None = None) → Tuple[List[Tuple[str, List[int]]], Array]

Flattens the parameters of the model into a single array and returns it together with a description of the original parameter structure.

Parameters:
  • params (jax.PyTree) – The parameters of the model.

  • num_batch_dims (int, optional) – The number of batch dimensions. Tensors are not flattened over these dimensions. Defaults to 0.

  • param_key (callable, optional) – A callable used as the key for sorting the parameters. Defaults to None.

Returns:

A tuple containing the structure of the flattened parameters and the flattened parameters themselves.

Return type:

Tuple[List[Tuple[str, List[int]]], jnp.ndarray]
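
Example:

The following is a minimal sketch of what a flattening of this kind could look like. It is not the fit_a_nef implementation. The function name flatten_params_sketch, the dotted parameter names, the assumption that params is a nested dict, and the assumption that param_key is applied to parameter names as a sort key are all hypothetical choices made for illustration. The real function may name, order, or lay out the parameters differently.

```python
import jax.numpy as jnp


def flatten_params_sketch(params, num_batch_dims=0, param_key=None):
    # Hypothetical stand-in for illustration, NOT the library's code.
    leaves = {}

    def collect(node, prefix):
        # Walk a nested dict and record each array under a dotted name.
        if isinstance(node, dict):
            for name, child in node.items():
                collect(child, f"{prefix}.{name}" if prefix else name)
        else:
            leaves[prefix] = jnp.asarray(node)

    collect(params, "")

    # Assumption: param_key is a sort key applied to the parameter names.
    names = sorted(leaves, key=param_key)

    # Structure: (name, shape without the leading batch dimensions).
    structure = [(n, list(leaves[n].shape[num_batch_dims:])) for n in names]

    # Keep the batch dimensions, flatten the rest, join along the last axis.
    flat = jnp.concatenate(
        [leaves[n].reshape(leaves[n].shape[:num_batch_dims] + (-1,)) for n in names],
        axis=-1,
    )
    return structure, flat


# Four models stacked along one batch dimension.
params = {
    "dense": {"kernel": jnp.ones((4, 3, 2)), "bias": jnp.zeros((4, 2))},
}
structure, flat = flatten_params_sketch(params, num_batch_dims=1)
print(structure)   # [('dense.bias', [2]), ('dense.kernel', [3, 2])]
print(flat.shape)  # (4, 8)
```

In this sketch, num_batch_dims=1 leaves the leading dimension of size 4 intact, so each of the four models gets one row of 2 + 6 = 8 values. The recorded structure holds the names and per-model shapes needed to split such a row back into individual parameters.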