Unflattens the parameters of the model.
Parameters:
param_config (List[Tuple[str, List[int]]]) – Structure of the flattened parameters.
comb_params (jnp.ndarray) – The flattened parameters.
Returns:
The parameters of the model.
Return type:
jax.PyTree
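The sketch below shows one way such unflattening could work. It is not the library's implementation, and it rests on several assumptions. It assumes each entry of `param_config` is a `(name, shape)` pair. It assumes the entries are listed in the same order the parameters were concatenated into a 1-D `comb_params`. It also returns a flat dict instead of whatever nested PyTree the real function may build. `unflatten_params` and `flatten_params` are hypothetical names used only for illustration.

```python
import math
from typing import Dict, List, Tuple

import jax.numpy as jnp


def flatten_params(
    params: Dict[str, jnp.ndarray],
    param_config: List[Tuple[str, List[int]]],
) -> jnp.ndarray:
    """Hypothetical inverse: ravel each parameter and concatenate in config order."""
    return jnp.concatenate([jnp.ravel(params[name]) for name, _ in param_config])


def unflatten_params(
    param_config: List[Tuple[str, List[int]]],
    comb_params: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
    """Hypothetical sketch: split a 1-D vector back into named, reshaped arrays.

    Assumes param_config holds (name, shape) pairs in concatenation order.
    """
    params = {}
    offset = 0
    for name, shape in param_config:
        size = math.prod(shape)  # element count; 1 for a scalar shape []
        chunk = comb_params[offset:offset + size]  # this parameter's slice
        params[name] = jnp.reshape(chunk, tuple(shape))
        offset += size
    # Every element of the flat vector should be consumed exactly once.
    if offset != comb_params.size:
        raise ValueError(
            f"param_config describes {offset} elements, "
            f"but comb_params has {comb_params.size}"
        )
    return params
```

Under these assumptions the two functions are inverses. For example, with `param_config = [("w", [2, 3]), ("b", [3])]` and `comb_params = jnp.arange(9.0)`, the first six values become the 2×3 weight `w` and the last three become the bias `b`. Passing the result back through `flatten_params` recovers the original vector.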