diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 63924478dbcbf8487fd82ef1bb474856a494c113..beb5c36024032d18d8b96d721b9e34843bfd01ef 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -21,7 +21,6 @@ import functools
 import inspect
 import io
 import warnings
-from copy import deepcopy
 from random import SystemRandom
 from typing import (
     # fmt: off
@@ -109,11 +108,11 @@ class HaikuModel(Model):
         # Select the device where to place computations.
         policy = get_device_policy()
         self._device = select_device(gpu=policy.gpu, idx=policy.idx)
-        # Create model state attributes
-        self._pleaves = []  # type: List[jax.Array]
-        self._treedef = None  # type: Optional[jax.tree_util.PyTreeDef]
-        self._trainable = []  # type: List[int]
-        # Initialize the PRNG
+        # Create model state attributes.
+        self._params = {}  # type: hk.Params
+        self._pnames = []  # type: List[str]
+        self._trainable = []  # type: List[str]
+        # Initialize the PRNG.
         if seed is None:
             seed = int(SystemRandom().random() * 10e6)
         self._rng_gen = hk.PRNGSequence(seed)
@@ -144,18 +143,20 @@ class HaikuModel(Model):
         self.data_info = aggregate_data_info(
             [data_info], self.required_data_info
         )
-        # initialize.
+        # Initialize the model parameters.
         params = self._transformed_model.init(
             next(self._rng_gen),
             jnp.zeros(
                 (1, *data_info["features_shape"]), data_info["data_type"]
             ),
         )
-        params = jax.device_put(params, self._device)
-        pleaves, treedef = jax.tree_util.tree_flatten(params)
-        self._treedef = treedef
-        self._pleaves = pleaves
-        self._trainable = list(range(len(pleaves)))
+        self._params = jax.device_put(params, self._device)
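+        # Record the weights' names, formatted as "{layer_name}:{weight_name}".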
+        self._pnames = [
+            f"{layer}:{weight}"
+            for layer, weights in self._params.items()
+            for weight in weights
+        ]
+        self._trainable = self._pnames.copy()
         self._initialized = True
 
     def get_config(
@@ -195,51 +196,15 @@ class HaikuModel(Model):
         self,
         trainable: bool = False,
     ) -> JaxNumpyVector:
-        params = {str(k): v for k, v in enumerate(self._pleaves)}
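+        # Flatten the parameters' nested dict, keying values by weight names.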
+        params = {
+            f"{layer}:{wname}": value
+            for layer, weights in self._params.items()
+            for wname, value in weights.items()
+        }
         if trainable:
-            params = {str(idx): params[str(idx)] for idx in self._trainable}
+            params = {k: v for k, v in params.items() if k in self._trainable}
         return JaxNumpyVector(params)
 
-    def get_named_weights(
-        self,
-        trainable: bool = False,
-    ) -> Dict[str, Dict[str, jax.Array]]:
-        """Access the weights of the haiku model as a nested dict.
-
-        Return type is any due to `jax.tree_util.tree_unflatten`.
-
-        trainable: bool, default=False
-            If True, restrict the returned weights to the trainable ones,
-            else return all weights.
-        """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        if trainable:
-            pop_idx = set(range(len(self._pleaves))) - set(self._trainable)
-            for i, (layer, name, _) in enumerate(self._traverse_params()):
-                if i in pop_idx:
-                    params[layer].pop(name)
-                if len(params[layer]) == 0:
-                    params.pop(layer)
-        return params
-
-    def _get_fixed_named_weights(self) -> Any:
-        """Access the fixed weights of the model as a nested dict, if any.
-
-        Return type is any due to `jax.tree_util.tree_unflatten`.
-        """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        if len(self._trainable) == len(self._pleaves):
-            return {}
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        pop_idx = set(self._trainable)
-        for i, (layer, name, _) in enumerate(self._traverse_params()):
-            if i in pop_idx:
-                params[layer].pop(name)
-            if len(params[layer]) == 0:
-                params.pop(layer)
-        return params
-
     def set_weights(  # type: ignore  # Vector subtype specification
         self,
         weights: JaxNumpyVector,
@@ -248,16 +213,9 @@ class HaikuModel(Model):
         if not isinstance(weights, JaxNumpyVector):
             raise TypeError("HaikuModel requires JaxNumpyVector weights.")
         self._verify_weights_compatibility(weights, trainable=trainable)
-        coefs_copy = deepcopy(weights.coefs)
-        if trainable:
-            for idx in self._trainable:
-                self._pleaves[idx] = jax.device_put(
-                    coefs_copy[str(idx)], self._device
-                )
-        else:
-            self._pleaves = [
-                jax.device_put(v, self._device) for v in coefs_copy.values()
-            ]
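+        # Assign input values, which are keyed by "{layer_name}:{weight_name}".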
+        for key, val in weights.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] = val.copy()  # type: ignore
 
     def _verify_weights_compatibility(
         self,
@@ -282,10 +240,7 @@ class HaikuModel(Model):
             are present. Be verbose about the identified mismatch(es).
         """
         received = set(vector.coefs)
-        if trainable:
-            expected = {str(i) for i in self._trainable}
-        else:
-            expected = {str(i) for i in range(len(self._pleaves))}
+        expected = set(self._trainable if trainable else self._pnames)
         raise_on_stringsets_mismatch(
             received, expected, context="model weights"
         )
@@ -295,7 +250,7 @@ class HaikuModel(Model):
         criterion: Union[
             Callable[[str, str, jax.Array], bool],
             Dict[str, Dict[str, Any]],
-            List[int],
+            List[str],
         ],
     ) -> None:
         """Sets the index of trainable weights.
@@ -335,7 +290,11 @@ class HaikuModel(Model):
         """
         if not self._initialized:
             raise ValueError("Model needs to be initialized first")
-        if isinstance(criterion, list) and isinstance(criterion[0], int):
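+        # Case when the criterion is an explicit list of weight names.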
+        if (
+            isinstance(criterion, list)
+            and all(isinstance(c, str) for c in criterion)
+            and all(c in self._pnames for c in criterion)
+        ):
             self._trainable = criterion
         else:
             self._trainable = []  # reset if needed
@@ -347,14 +306,12 @@ class HaikuModel(Model):
                 include_fn = self._build_include_fn(criterion)
             else:
                 raise TypeError(
-                    "The provided criterion does not conform"
-                    "to the expected format and or type"
+                    "The provided criterion does not conform "
+                    "to the expected format and/or type."
                 )
-            for idx, (layer, name, value) in enumerate(
-                self._traverse_params()
-            ):
+            for layer, name, value in self._traverse_params():
                 if include_fn(layer, name, value):
-                    self._trainable.append(idx)
+                    self._trainable.append(f"{layer}:{name}")
 
     def _traverse_params(self) -> Iterator[Tuple[str, str, jax.Array]]:
         """Traverse the pytree of a model's named weights.
@@ -362,13 +319,11 @@ class HaikuModel(Model):
         Yield (layer_name, weight_name, weight_value) tuples from
         traversing the pytree left-to-right, depth-first.
         """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        for layer in params:
-            bundle = params[layer]
-            for name in bundle:
-                value = bundle[name]
-                yield layer, name, value
+        yield from (
+            (layer, weight, value)
+            for layer, weights in self._params.items()
+            for weight, value in weights.items()
+        )
 
     @staticmethod
     def _build_include_fn(
@@ -384,6 +339,57 @@ class HaikuModel(Model):
 
         return include_fn
 
+    def get_weight_names(
+        self,
+        trainable: bool = False,
+    ) -> List[str]:
+        """Return the list of names of the model's weights.
+
+        Parameters
+        ----------
+        trainable: bool, default=False
+            Whether to return only the names of trainable weights,
+            rather than including both trainable and frozen ones.
+
+        Returns
+        -------
+        names:
+            Ordered list of model weights' names.
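+
+        Notes
+        -----
+        Weight names are formatted as "{layer_name}:{weight_name}" strings,
+        mirroring the nested structure of the wrapped haiku parameters
+        (e.g. "linear:w" for the kernel of a `hk.Linear` module).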
+        """
+        return self._trainable.copy() if trainable else self._pnames.copy()
+
+    def compute_batch_gradients(
+        self,
+        batch: Batch,
+        max_norm: Optional[float] = None,
+    ) -> JaxNumpyVector:
+        # Unpack input batch and prepare model parameters.
+        inputs = self._unpack_batch(batch)
+        train_params, fixed_params = hk.data_structures.partition(
+            predicate=lambda l, w, _: f"{l}:{w}" in self._trainable,
+            structure=self._params,
+        )
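+        # 'train_params' holds the trainable weights; 'fixed_params' the frozen ones.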
+        rng = next(self._rng_gen)
+        # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
+        if max_norm:
+            grads = self._clipped_grad_fn(
+                train_params, fixed_params, rng, inputs, max_norm
+            )
+            grads = [value.mean(0) for value in grads]
+        else:
+            grads = jax.tree_util.tree_leaves(
+                self._grad_fn(train_params, fixed_params, rng, inputs)
+            )
+        # Return the gradients, flattened into a JaxNumpyVector container.
+        return JaxNumpyVector(dict(zip(self._trainable, grads)))
+
+    @functools.cached_property
+    def _grad_fn(
+        self,
+    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+        """Lazy-built jax function to compute batch-averaged gradients."""
+        return jax.jit(jax.grad(self._forward))
+
     def _forward(
         self,
         train_params: hk.Params,
@@ -399,7 +405,7 @@ class HaikuModel(Model):
             The model parameters, as a nested dict of jax arrays.
         rng: jax.Array
             A jax pseudo-random number generator (PRNG) key.
-        batch: (jax.Array, jax.Array, optional[jax.Array])
+        batch: (list[jax.Array], jax.Array, optional[jax.Array])
             Batch of jax-converted inputs, comprising (in that order)
             input data, ground-truth labels and optional sample weights.
 
@@ -408,7 +414,6 @@ class HaikuModel(Model):
         loss: jax.Array
             The mean loss over the input data provided.
         """
-        # FUTURE: add support for lists of inputs
         inputs, y_true, s_wght = batch
         params = hk.data_structures.merge(train_params, fixed_params)
         y_pred = self._transformed_model.apply(params, rng, *inputs)
@@ -417,64 +422,6 @@ class HaikuModel(Model):
             s_loss = s_loss * s_wght
         return jnp.mean(s_loss)
 
-    def compute_batch_gradients(
-        self,
-        batch: Batch,
-        max_norm: Optional[float] = None,
-    ) -> JaxNumpyVector:
-        if max_norm:
-            return self._compute_clipped_gradients(batch, max_norm)
-        return self._compute_batch_gradients(batch)
-
-    def _compute_batch_gradients(
-        self,
-        batch: Batch,
-    ) -> JaxNumpyVector:
-        """Compute and return batch-averaged gradients of trainable weights."""
-        # Unpack input batch and unflatten model parameters.
-        inputs = self._unpack_batch(batch)
-        train_params = self.get_named_weights(trainable=True)
-        fixed_params = self._get_fixed_named_weights()
-        # Run the forward and backward passes to compute gradients.
-        grads = self._grad_fn(
-            train_params,
-            fixed_params,
-            next(self._rng_gen),
-            inputs,
-        )
-        # Flatten the gradients and return them in a Vector container.
-        flat_grad, _ = jax.tree_util.tree_flatten(grads)
-        return JaxNumpyVector({str(k): v for k, v in enumerate(flat_grad)})
-
-    @functools.cached_property
-    def _grad_fn(
-        self,
-    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
-        """Lazy-built jax function to compute batch-averaged gradients."""
-        return jax.jit(jax.grad(self._forward))
-
-    def _compute_clipped_gradients(
-        self,
-        batch: Batch,
-        max_norm: float,
-    ) -> JaxNumpyVector:
-        """Compute and return sample-wise clipped, batch-averaged gradients."""
-        # Unpack input batch and unflatten model parameters.
-        inputs = self._unpack_batch(batch)
-        train_params = self.get_named_weights(trainable=True)
-        fixed_params = self._get_fixed_named_weights()
-        # Get flattened, per-sample, clipped gradients and aggregate them.
-        clipped_grads = self._clipped_grad_fn(
-            train_params,
-            fixed_params,
-            next(self._rng_gen),
-            inputs,
-            max_norm,
-        )
-        grads = [g.sum(0) for g in clipped_grads]
-        # Return them in a Vector container.
-        return JaxNumpyVector({str(k): v for k, v in enumerate(grads)})
-
     @functools.cached_property
     def _clipped_grad_fn(
         self,
@@ -484,7 +431,8 @@ class HaikuModel(Model):
         """Lazy-built jax function to compute clipped sample-wise gradients.
 
-        Note : The vmap in_axis parameters work thank to the jax feature of
-        applying optional parameters to pytrees."""
+        Note: The vmap in_axes parameters work thanks to the jax feature of
+        applying optional parameters to pytrees.
+        """
 
         def clipped_grad_fn(
             train_params: hk.Params,
@@ -497,10 +445,7 @@ class HaikuModel(Model):
             inputs, y_true, s_wght = batch
             batch = (inputs, y_true, None)
             grads = jax.grad(self._forward)(
-                train_params,
-                fixed_params,
-                rng,
-                batch,
+                train_params, fixed_params, rng, batch
             )
             grads_flat = [
                 grad / jnp.maximum(jnp.linalg.norm(grad) / max_norm, 1.0)
@@ -539,8 +484,9 @@ class HaikuModel(Model):
         if not isinstance(updates, JaxNumpyVector):
             raise TypeError("HaikuModel requires JaxNumpyVector updates.")
         self._verify_weights_compatibility(updates, trainable=True)
-        for key, upd in updates.coefs.items():
-            self._pleaves[int(key)] += upd
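+        # Updates are keyed by "{layer_name}:{weight_name}"; apply them in place.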
+        for key, val in updates.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] += val  # type: ignore
 
     def compute_batch_predictions(
         self,
@@ -554,13 +500,14 @@ class HaikuModel(Model):
                 "correct the inputs, or override this method to support "
                 "creating labels from the base inputs."
             )
-        params = self.get_named_weights()
-        y_pred = np.asarray(
-            self._transformed_model.apply(params, next(self._rng_gen), *inputs)
+        y_pred = self._transformed_model.apply(
+            self._params, next(self._rng_gen), *inputs
+        )
+        return (
+            np.asarray(y_true),
+            np.asarray(y_pred),
+            None if s_wght is None else np.asarray(s_wght),
         )
-        y_true = np.asarray(y_true)  # type: ignore
-        s_wght = None if s_wght is None else np.asarray(s_wght)  # type: ignore
-        return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
         self,
@@ -582,4 +529,4 @@ class HaikuModel(Model):
         # When needed, re-create the model to force moving it to the device.
         if self._device is not device:
             self._device = device
-            self._pleaves = jax.device_put(self._pleaves, self._device)
+            self._params = jax.device_put(self._params, self._device)
diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
index 8083dcda507280d5e7110490e7a07fe0f327a387..a2b8a2c6bf0dcdfca20dbaee05ad63c0379222b5 100644
--- a/test/model/test_haiku.py
+++ b/test/model/test_haiku.py
@@ -18,7 +18,7 @@
 """Unit tests for HaikuModel."""
 
 import sys
-from typing import Any, List, Literal, Tuple
+from typing import Any, Callable, Dict, List, Literal, Union
 
 import numpy as np
 import pytest
@@ -176,7 +176,8 @@ class HaikuTestCase(ModelTestCase):
             }
         )
         if self.kind == "MLP-tune":
-            model.set_trainable_weights([0, 1, 2])
+            names = model.get_weight_names()
+            model.set_trainable_weights(names[:3])
         return model
 
     def assert_correct_device(
@@ -190,17 +191,24 @@ class HaikuTestCase(ModelTestCase):
             for arr in vector.coefs.values()
         )
 
-    def get_trainable_criterion(self, c_type: str):
+    def get_trainable_criterion(
+        self,
+        c_type: str,
+    ) -> Union[
+        List[str],
+        Dict[str, Dict[str, Any]],
+        Callable[[str, str, jax.Array], bool],
+    ]:
         "Build different weight freezing criteria"
-        if c_type == "indexes":
-            crit = [2, 3]
+        if c_type == "names":
+            names = self.model.get_weight_names()
+            return [names[2], names[3]]
         if c_type == "pytree":
-            params = self.model.get_named_weights()
-            params.pop(list(params.keys())[1])
-            crit = params
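+            # Build a partial copy of the model's private parameters' dict.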
+            params = getattr(self.model, "_params")
+            return {k: v for i, (k, v) in enumerate(params.items()) if i != 1}
         if c_type == "predicate":
-            crit = lambda m, n, p: n != "b"  # pylint: disable=C3001
-        return crit
+            return lambda m, n, p: n != "b"
+        raise KeyError(f"Invalid 'c_type' parameter: {c_type}.")
 
 
 @pytest.fixture(name="test_case")
@@ -230,7 +238,7 @@ class TestHaikuModel(ModelTestSuite):
         super().test_serialization(test_case)
 
     @pytest.mark.parametrize(
-        "criterion_type", ["indexes", "pytree", "predicate"]
+        "criterion_type", ["names", "pytree", "predicate"]
     )
     def test_get_frozen_weights(
         self,
@@ -244,14 +252,14 @@ class TestHaikuModel(ModelTestSuite):
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
         assert set(w_trn.coefs).issubset(w_all.coefs)  # check on keys
-        n_params = len(model._pleaves)  # pylint: disable=protected-access
-        n_trainable = len(model._trainable)  # pylint: disable=protected-access
+        n_params = len(model.get_weight_names())
+        n_trainable = len(model.get_weight_names(trainable=True))
         assert n_trainable < n_params
         assert len(w_trn.coefs) == n_trainable
         assert len(w_all.coefs) == n_params
 
     @pytest.mark.parametrize(
-        "criterion_type", ["indexes", "pytree", "predicate"]
+        "criterion_type", ["names", "pytree", "predicate"]
     )
     def test_set_frozen_weights(
         self,
@@ -281,7 +289,7 @@ class TestHaikuModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "gpu")
         assert policy.idx == 0
-        params = getattr(model, "_pleaves")
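+        # Gather all parameter arrays to check their device placement.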
+        params = jax.tree_util.tree_leaves(getattr(model, "_params"))
         device = f"{test_case.device}:0"
         for arr in params:
             assert f"{arr.device().platform}:{arr.device().id}" == device