diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 7ed66f802f670dd912102cdc33af24f94a5fbd66..c6e4646673f4edf5872ecb3524231f4a4de9a627 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -56,6 +56,7 @@ class HaikuModel(Model):
     instance to be learned federatively.
 
     Notes regarding device management (CPU, GPU, etc.):
+
     * By default, jax places data and operations on GPU whenever one
       is available.
     * Our `HaikuModel` instead consults the device-placement policy (via
@@ -257,33 +258,56 @@ class HaikuModel(Model):
         an explicit dict of names or even the index of the parameter leaves
         stored by our HaikuModel.
 
-        Example use :
-            >>> self.get_named_weights() = {'linear': {'w': None, 'b': None}}
-        Using a function as input
-            >>> criterion = lambda layer, name, value: name == 'w'
-            >>> self.set_trainable_weights(criterion)
-            >>> self._trainable
-            [0]
-        Using a dictionnary or pytree
-            >>> criterion = {'linear': {'b': None}}
-            >>> self.set_trainable_weights(criterion)
-            >>> self._trainable
-            [1]
-
-        Note : model needs to be initialized
-
-        Arguments
-        --------
-        criterion : Callable or dict(str,dict(str,any)) or list(int)
-            Criterion to be used to identify trainable params. If Callable,
-            must be a function taking in the name of the module (e.g.
-            layer name), the element name (e.g. parameter name) and the
-            corresponding data and returning a boolean. See
-            [the haiku doc](https://tinyurl.com/3v28upaz)
-            for details. If a list of integers, should represent the index of
-            trainable  parameters in the parameter tree leaves. If a dict,
-            should be formatted as a pytree.
+        Notes
+        -----
+        - The model needs to be initialized for this method to work.
+        - The list of model weight names (general, or restricted to trainable
+          ones) may be consulted using the `get_weight_names` method.
+
+        Usage
+        -----
+
+        Let us pretend the model is made of a single linear layer; we want
+        to freeze its bias, leaving only the kernel weights trainable.
+        ```
+        >>> # Display current names of trainable model weights.
+        >>> self.get_weight_names(trainable=True)
+        ["linear/~/w", "linear/~/b"]
+        ```
+        - (A) Using a list of weight names:
+        ```
+        >>> criterion = ["linear/~/w"]
+        >>> self.set_trainable_weights(criterion)
+        ```
+        - (B) Using a function as input:
+        ```
+        >>> criterion = lambda layer, name, value: name == 'w'
+        >>> self.set_trainable_weights(criterion)
+        ```
+        - (C) Using a dictionnary or pytree:
+        ```
+        >>> criterion = {'linear': {'b': None}}
+        >>> self.set_trainable_weights(criterion)
+        ```
+        - In all three cases, we can verify the results.
+        ```
+        >>> self.get_weight_names(trainable=True)
+        ["linear/~/w"]
+        ```
 
+        Parameters
+        ----------
+        criterion: Callable or dict(str,dict(str,any)) or list(int)
+            Criterion to be used to identify trainable params.
+
+            - If a list of strings, should represent the names of weights to
+              keep as trainable (freezing each and every other one).
+            - If callable, must be a function taking in the name of the module
+              (e.g. layer name), the element name (e.g. parameter name) and the
+              corresponding data and returning a boolean.
+              See [the haiku doc](https://tinyurl.com/3v28upaz) for details.
+            - If a dict, should be formatted as a pytree, the keys of which
+              are the nodes/leaves that should remain trainable.
         """
         if not self._initialized:
             raise ValueError("Model needs to be initialized first")