diff --git a/declearn/model/__init__.py b/declearn/model/__init__.py
index 5a84dc83c7ad4381526c2cc8a409495c69b02313..71b604e8ad37c380324cd623c2294259ffc54327 100644
--- a/declearn/model/__init__.py
+++ b/declearn/model/__init__.py
@@ -42,6 +42,9 @@ The automatically-imported submodules implemented here are:
 Optional Submodules
 -------------------
 The optional-dependency-based submodules that may be manually imported are:
+* haiku: jax- and haiku-interfacing tools
+    - HaikuModel: Model to wrap a haiku-transformable model function.
+    - JaxNumpyVector: Vector for jax array data structures.
 
 * [tensorflow][declearn.model.tensorflow]:
     TensorFlow-interfacing tools
@@ -65,6 +68,7 @@ from . import api
 from . import sklearn
 
 OPTIONAL_MODULES = [
+    "jax",
     "tensorflow",
     "torch",
 ]
diff --git a/declearn/model/haiku/__init__.py b/declearn/model/haiku/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adb8f1f2d0d9ec2593a3dcacc1d25bd808070cf
--- /dev/null
+++ b/declearn/model/haiku/__init__.py
@@ -0,0 +1,31 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Haiku models interfacing tools.
+
+This submodule provides with a generic interface to wrap up
+any Haiku module instance that is to be trained
+through gradient descent.
+
+This module exposes:
+* HaikuModel: Model subclass to wrap haiku.Model objects
+* JaxNumpyVector: Vector subclass to wrap jax.numpy.ndarray objects
+"""
+
+from . import utils
+from ._vector import JaxNumpyVector
+from ._model import HaikuModel
diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..17d19c87af034f41c9064dad7ac9460bf1cafb48
--- /dev/null
+++ b/declearn/model/haiku/_model.py
@@ -0,0 +1,521 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model subclass to wrap Haiku models."""
+
+import functools
+import inspect
+import io
+import warnings
+from random import SystemRandom
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, List, Optional, Set, Tuple, Union
+)
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import joblib  # type: ignore
+import numpy as np
+from typing_extensions import Self
+
+from declearn.data_info import aggregate_data_info
+from declearn.model._utils import raise_on_stringsets_mismatch
+from declearn.model.api import Model
+from declearn.model.haiku._vector import JaxNumpyVector
+from declearn.model.haiku.utils import select_device
+from declearn.typing import Batch
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+__all__ = [
+    "HaikuModel",
+]
+
+# alias for unpacked Batch structures, converted to jax arrays
+# FUTURE: add support for lists of inputs
+JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]]
+
+
+@register_type(name="HaikuModel", group="Model")
+class HaikuModel(Model):
+    """Model wrapper for Haiku Model instances.
+
+    This `Model` subclass is designed to wrap a `hk.Module`
+    instance to be learned federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default, jax places data and operations on GPU whenever one
+      is available.
+    * Our `HaikuModel` instead consults the device-placement policy (via
+      `declearn.utils.get_device_policy`), places the wrapped haiku model's
+      weights there, and runs computations defined under public methods on
+      that device.
+    * Note that there is no guarantee that calling a private method directly
+      will result in abiding by that policy. Hence, be careful when writing
+      custom code, and use your own context managers to get guarantees.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy enforced by a HaikuModel
+      instance by accessing its `device_policy` property.
+    """
+
+    # pylint: disable=too-many-instance-attributes
+
+    def __init__(
+        self,
+        model: Callable[[jax.Array], jax.Array],
+        loss: Callable[[jax.Array, jax.Array], jax.Array],
+        seed: Optional[int] = None,
+    ) -> None:
+        """Instantiate a Model interface wrapping a jax-haiku model.
+
+        Parameters
+        ----------
+        model: callable(jax.Array) -> jax.Array
+            Function encapsulating a `haiku.Module` such that `model(x)`
+            returns `haiku_module(x)`, constituting a model's forward.
+        loss: callable(jax.Array, jax.Array) -> jax.Array
+            Jax-compatible function that defines the model's loss.
+            It must expect `y_pred` and `y_true` as input arguments (in
+            that order) and return sample-wise loss values.
+        seed: int or None, default=None
+            Random seed used to initialize the haiku-wrapped Pseudo-random
+            number generator. If none is provided, draw an integer between
+            0 and 10^6 using `random.SystemRandom`.
+        """
+        super().__init__(model)
+        # Assign loss module.
+        self._loss_fn = loss
+        # Get pure functions from haiku transform.
+        self._model_fn = model
+        self._transformed_model = hk.transform(model)
+        # Select the device where to place computations.
+        policy = get_device_policy()
+        self._device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Create model state attributes.
+        self._params = {}  # type: hk.Params
+        self._pnames = []  # type: List[str]
+        self._trainable = []  # type: List[str]
+        # Initialize the PRNG.
+        if seed is None:
+            seed = int(SystemRandom().random() * 10e6)
+        self._rng_gen = hk.PRNGSequence(seed)
+        # Initialized and data_info utils
+        self._initialized = False
+        self.data_info = {}  # type: Dict[str, Any]
+
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._device
+        return DevicePolicy(gpu=(device.platform == "gpu"), idx=device.id)
+
+    @property
+    def required_data_info(
+        self,
+    ) -> Set[str]:
+        return set() if self._initialized else {"data_type", "features_shape"}
+
+    def initialize(
+        self,
+        data_info: Dict[str, Any],
+    ) -> None:
+        if self._initialized:
+            return
+        # Check that required fields are available and of valid type.
+        self.data_info = aggregate_data_info(
+            [data_info], self.required_data_info
+        )
+        # Initialize the model parameters.
+        params = self._transformed_model.init(
+            next(self._rng_gen),
+            jnp.zeros(
+                (1, *data_info["features_shape"]), data_info["data_type"]
+            ),
+        )
+        self._params = jax.device_put(params, self._device)
+        self._pnames = [
+            f"{layer}:{weight}"
+            for layer, weights in self._params.items()
+            for weight in weights
+        ]
+        self._trainable = self._pnames.copy()
+        self._initialized = True
+
+    def get_config(
+        self,
+    ) -> Dict[str, Any]:
+        warnings.warn(
+            "Our custom Haiku serialization relies on pickle,"
+            "which may be unsafe."
+        )
+        with io.BytesIO() as buffer:
+            joblib.dump(self._model_fn, buffer)
+            model = buffer.getbuffer().hex()
+        with io.BytesIO() as buffer:
+            joblib.dump(self._loss_fn, buffer)
+            loss = buffer.getbuffer().hex()
+        return {
+            "model": model,
+            "loss": loss,
+            "data_info": self.data_info,
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Dict[str, Any],
+    ) -> Self:
+        with io.BytesIO(bytes.fromhex(config["model"])) as buffer:
+            model = joblib.load(buffer)
+        with io.BytesIO(bytes.fromhex(config["loss"])) as buffer:
+            loss = joblib.load(buffer)
+        model = cls(model=model, loss=loss)
+        if config.get("data_info"):
+            model.initialize(config["data_info"])
+        return model
+
+    def get_weights(
+        self,
+        trainable: bool = False,
+    ) -> JaxNumpyVector:
+        params = {
+            f"{layer}:{wname}": value
+            for layer, weights in self._params.items()
+            for wname, value in weights.items()
+        }
+        if trainable:
+            params = {k: v for k, v in params.items() if k in self._trainable}
+        return JaxNumpyVector(params)
+
+    def set_weights(  # type: ignore  # Vector subtype specification
+        self,
+        weights: JaxNumpyVector,
+        trainable: bool = False,
+    ) -> None:
+        if not isinstance(weights, JaxNumpyVector):
+            raise TypeError("HaikuModel requires JaxNumpyVector weights.")
+        self._verify_weights_compatibility(weights, trainable=trainable)
+        for key, val in weights.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] = val.copy()  # type: ignore
+
+    def _verify_weights_compatibility(
+        self,
+        vector: JaxNumpyVector,
+        trainable: bool = False,
+    ) -> None:
+        """Verify that a vector has the same names as the model's weights.
+
+        Parameters
+        ----------
+        vector: JaxNumpyVector
+            Vector wrapping weight-related coefficients (e.g. weight
+            values or gradient-based updates).
+        trainable: bool, default=False
+            Whether to restrict the comparision to the model's trainable
+            weights rather than to all of its weights.
+
+        Raises
+        ------
+        KeyError:
+            In case some expected keys are missing, or additional keys
+            are present. Be verbose about the identified mismatch(es).
+        """
+        received = set(vector.coefs)
+        expected = set(self._trainable if trainable else self._pnames)
+        raise_on_stringsets_mismatch(
+            received, expected, context="model weights"
+        )
+
+    def set_trainable_weights(
+        self,
+        criterion: Union[
+            Callable[[str, str, jax.Array], bool],
+            Dict[str, Dict[str, Any]],
+            List[str],
+        ],
+    ) -> None:
+        """Sets the index of trainable weights.
+
+        The split can be done by providing a functions applying conditions on
+        the named weights, as haiku users are used to do, but can also accept
+        an explicit dict of names or even the index of the parameter leaves
+        stored by our HaikuModel.
+
+        Example use :
+            >>> self.get_named_weights() = {'linear': {'w': None, 'b': None}}
+        Using a function as input
+            >>> criterion = lambda layer, name, value: name == 'w'
+            >>> self.set_trainable_weights(criterion)
+            >>> self._trainable
+            [0]
+        Using a dictionnary or pytree
+            >>> criterion = {'linear': {'b': None}}
+            >>> self.set_trainable_weights(criterion)
+            >>> self._trainable
+            [1]
+
+        Note : model needs to be initialized
+
+        Arguments
+        --------
+        criterion : Callable or dict(str,dict(str,any)) or list(int)
+            Criterion to be used to identify trainable params. If Callable,
+            must be a function taking in the name of the module (e.g.
+            layer name), the element name (e.g. parameter name) and the
+            corresponding data and returning a boolean. See
+            [the haiku doc](https://tinyurl.com/3v28upaz)
+            for details. If a list of integers, should represent the index of
+            trainable  parameters in the parameter tree leaves. If a dict,
+            should be formatted as a pytree.
+
+        """
+        if not self._initialized:
+            raise ValueError("Model needs to be initialized first")
+        if (
+            isinstance(criterion, list)
+            and all(isinstance(c, str) for c in criterion)
+            and all(c in self._pnames for c in criterion)
+        ):
+            self._trainable = criterion
+        else:
+            self._trainable = []  # reset if needed
+            if inspect.isfunction(criterion):
+                include_fn = (
+                    criterion
+                )  # type: Callable[[str, str, jax.Array], bool]
+            elif isinstance(criterion, dict):
+                include_fn = self._build_include_fn(criterion)
+            else:
+                raise TypeError(
+                    "The provided criterion does not conform "
+                    "to the expected format and or type."
+                )
+            gen = hk.data_structures.traverse(self._params)
+            for layer, name, value in gen:
+                if include_fn(layer, name, value):
+                    self._trainable.append(f"{layer}:{name}")
+
+    @staticmethod
+    def _build_include_fn(
+        node_names: Dict[str, Dict[str, Any]],
+    ) -> Callable[[str, str, jax.Array], bool]:
+        """Build an equality-checking function for parameters' traversal."""
+
+        def include_fn(layer: str, name: str, value: jax.Array) -> bool:
+            # mandatory signature; pylint: disable=unused-argument
+            if layer in list(node_names.keys()):
+                return name in list(node_names[layer].keys())
+            return False
+
+        return include_fn
+
+    def get_weight_names(
+        self,
+        trainable: bool = False,
+    ) -> List[str]:
+        """Return the list of names of the model's weights.
+
+        Parameters
+        ----------
+        trainable: bool
+            Whether to return only the names of trainable weights,
+            rather than including both trainable and frozen ones.
+
+        Returns
+        -------
+        names:
+            Ordered list of model weights' names.
+        """
+        return self._trainable.copy() if trainable else self._pnames.copy()
+
+    def compute_batch_gradients(
+        self,
+        batch: Batch,
+        max_norm: Optional[float] = None,
+    ) -> JaxNumpyVector:
+        # Unpack input batch and prepare model parameters.
+        inputs = self._unpack_batch(batch)
+        train_params, fixed_params = hk.data_structures.partition(
+            predicate=lambda l, w, _: f"{l}:{w}" in self._trainable,
+            structure=self._params,
+        )
+        rng = next(self._rng_gen)
+        # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
+        if max_norm:
+            grads = self._clipped_grad_fn(
+                train_params, fixed_params, rng, inputs, max_norm
+            )
+            grads = [value.mean(0) for value in grads]
+        else:
+            grads = jax.tree_util.tree_leaves(
+                self._grad_fn(train_params, fixed_params, rng, inputs)
+            )
+        # Return the gradients, flattened into a JaxNumpyVector container.
+        return JaxNumpyVector(dict(zip(self._trainable, grads)))
+
+    @functools.cached_property
+    def _grad_fn(
+        self,
+    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+        """Lazy-built jax function to compute batch-averaged gradients."""
+        return jax.jit(jax.grad(self._forward))
+
+    def _forward(
+        self,
+        train_params: hk.Params,
+        fixed_params: hk.Params,
+        rng: jax.Array,
+        batch: JaxBatch,
+    ) -> jax.Array:
+        """The forward pass chaining the model to the loss as a pure function.
+
+        Parameters
+        -------
+        params: haiku.Params
+            The model parameters, as a nested dict of jax arrays.
+        rng: jax.Array
+            A jax pseudo-random number generator (PRNG) key.
+        batch: (list[jax.Array], jax.Array, optional[jax.Array])
+            Batch of jax-converted inputs, comprising (in that order)
+            input data, ground-truth labels and optional sample weights.
+
+        Returns
+        -------
+        loss: jax.Array
+            The mean loss over the input data provided.
+        """
+        inputs, y_true, s_wght = batch
+        params = hk.data_structures.merge(train_params, fixed_params)
+        y_pred = self._transformed_model.apply(params, rng, *inputs)
+        s_loss = self._loss_fn(y_pred, y_true)  # type: ignore
+        if s_wght is not None:
+            s_loss = s_loss * s_wght
+        return jnp.mean(s_loss)
+
+    @functools.cached_property
+    def _clipped_grad_fn(
+        self,
+    ) -> Callable[
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
+    ]:
+        """Lazy-built jax function to compute clipped sample-wise gradients.
+
+        Note : The vmap in_axis parameters work thank to the jax feature of
+        applying optional parameters to pytrees.
+        """
+
+        def clipped_grad_fn(
+            train_params: hk.Params,
+            fixed_params: hk.Params,
+            rng: jax.Array,
+            batch: JaxBatch,
+            max_norm: float,
+        ) -> List[jax.Array]:
+            """Compute and clip gradients wrt parameters for a sample."""
+            inputs, y_true, s_wght = batch
+            batch = (inputs, y_true, None)
+            grads = jax.grad(self._forward)(
+                train_params, fixed_params, rng, batch
+            )
+            grads_flat = [
+                grad / jnp.maximum(jnp.linalg.norm(grad) / max_norm, 1.0)
+                for grad in jax.tree_util.tree_leaves(grads)
+            ]
+            if s_wght is not None:
+                grads_flat = [g * s_wght for g in grads_flat]
+            return grads_flat
+
+        in_axes = [None, None, None, 0, None]  # map on inputs' first dimension
+        return jax.jit(jax.vmap(clipped_grad_fn, in_axes))
+
+    @staticmethod
+    def _unpack_batch(batch: Batch) -> JaxBatch:
+        """Unpack and enforce jnp.array conversion to an input data batch."""
+
+        def convert(data: Any) -> Optional[jax.Array]:
+            if (data is None) or isinstance(data, jax.Array):
+                return data
+            if isinstance(data, np.ndarray):
+                return jnp.array(data)  # pylint: disable=no-member
+            raise TypeError("HaikuModel requires numpy or jax.numpy data.")
+
+        # similar code to TorchModel; pylint: disable=duplicate-code
+        # Convert batched data to jax Arrays.
+        inputs, y_true, s_wght = batch
+        if not isinstance(inputs, (tuple, list)):
+            inputs = [inputs]
+        output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)]
+        return output  # type: ignore
+
+    def apply_updates(  # type: ignore  # Vector subtype specification
+        self,
+        updates: JaxNumpyVector,
+    ) -> None:
+        if not isinstance(updates, JaxNumpyVector):
+            raise TypeError("HaikuModel requires JaxNumpyVector updates.")
+        self._verify_weights_compatibility(updates, trainable=True)
+        for key, val in updates.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] += val  # type: ignore
+
+    def compute_batch_predictions(
+        self,
+        batch: Batch,
+    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
+        inputs, y_true, s_wght = self._unpack_batch(batch)
+        if y_true is None:
+            raise TypeError(
+                "`HaikuModel.compute_batch_predictions` received a "
+                "batch with `y_true=None`, which is unsupported. Please "
+                "correct the inputs, or override this method to support "
+                "creating labels from the base inputs."
+            )
+        y_pred = self._transformed_model.apply(
+            self._params, next(self._rng_gen), *inputs
+        )
+        return (
+            np.asarray(y_true),
+            np.asarray(y_pred),
+            None if s_wght is None else np.asarray(s_wght),
+        )
+
+    def loss_function(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+    ) -> np.ndarray:
+        s_loss = self._loss_fn(jnp.array(y_pred), jnp.array(y_true))
+        return np.array(s_loss).squeeze()
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        # similar code to tensorflow Model; pylint: disable=duplicate-code
+        # Select the device to use based on the provided or global policy.
+        if policy is None:
+            policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # When needed, re-create the model to force moving it to the device.
+        if self._device is not device:
+            self._device = device
+            self._params = jax.device_put(self._params, self._device)
diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ac8cafe77c12898d9e56b1da7166da51163a65
--- /dev/null
+++ b/declearn/model/haiku/_vector.py
@@ -0,0 +1,170 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""JaxNumpyVector data arrays container."""
+
+from typing import Any, Callable, Dict, Optional, Set, Type
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.config import config as jaxconfig
+from typing_extensions import Self  # future: import from typing (Py>=3.11)
+
+from declearn.model.api import Vector, register_vector_type
+from declearn.model.haiku.utils import select_device
+from declearn.model.sklearn import NumpyVector
+from declearn.utils import get_device_policy
+
+__all__ = [
+    "JaxNumpyVector",
+]
+
+
+jaxconfig.update("jax_enable_x64", True)  # enable float64 support
+
+
+@register_vector_type(jax.Array)
+class JaxNumpyVector(Vector):
+    """Vector subclass to store jax.numpy.ndarray coefficients.
+
+    This Vector is designed to store a collection of named
+    jax numpy arrays or scalars, enabling computations that are
+    either applied to each and every coefficient, or imply
+    two sets of aligned coefficients (i.e. two JaxNumpyVector
+    instances with similar coefficients specifications).
+
+    Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `JaxnumpyVector` can be operated with either a:
+      - scalar value
+      - `NumpyVector` that has similar specifications
+      - `JaxNumpyVector` that has similar specifications
+      => resulting in a `JaxNumpyVector` in each of these cases.
+    - The wrapped arrays may be placed on any device (CPU, GPU...)
+      and may not be all on the same device.
+    - The device-placement of the initial `JaxNumpyVector`'s data
+      is preserved by operations, including with `NumpyVector`.
+    - When combining two `JaxNumpyVector`, the device-placement
+      of the left-most one is used; in that case, one ends up with
+      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
+      warning will be emitted to prevent silent un-optimized copies.
+    - When deserializing a `JaxNumpyVector` (either by directly using
+      `JaxNumpyVector.unpack` or loading one from a JSON dump), loaded
+      arrays are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement schema than at dump time but
+      should be coherent with that of `HaikuModel` computations.
+    """
+
+    @property
+    def _op_add(self) -> Callable[[Any, Any], Any]:
+        return jnp.add
+
+    @property
+    def _op_sub(self) -> Callable[[Any, Any], Any]:
+        return jnp.subtract
+
+    @property
+    def _op_mul(self) -> Callable[[Any, Any], Any]:
+        return jnp.multiply
+
+    @property
+    def _op_div(self) -> Callable[[Any, Any], Any]:
+        return jnp.divide
+
+    @property
+    def _op_pow(self) -> Callable[[Any, Any], Any]:
+        return jnp.power
+
+    @property
+    def compatible_vector_types(self) -> Set[Type[Vector]]:
+        types = super().compatible_vector_types
+        return types.union({NumpyVector, JaxNumpyVector})
+
+    def __init__(self, coefs: Dict[str, jax.Array]) -> None:
+        super().__init__(coefs)
+
+    def _apply_operation(
+        self,
+        other: Any,
+        func: Callable[[Any, Any], Any],
+    ) -> Self:
+        # Ensure 'other' JaxNumpyVector shares this vector's device placement.
+        if isinstance(other, JaxNumpyVector):
+            coefs = {
+                key: jax.device_put(val, self.coefs[key].device())
+                for key, val in other.coefs.items()
+            }
+            other = JaxNumpyVector(coefs)
+        return super()._apply_operation(other, func)
+
+    def __eq__(self, other: Any) -> bool:
+        valid = isinstance(other, JaxNumpyVector)
+        valid = valid and (self.coefs.keys() == other.coefs.keys())
+        return valid and all(
+            jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
+        )
+
+    def sign(
+        self,
+    ) -> Self:
+        return self.apply_func(jnp.sign)
+
+    def minimum(
+        self,
+        other: Any,
+    ) -> Self:
+        if isinstance(other, JaxNumpyVector):
+            return self._apply_operation(other, jnp.minimum)
+        return self.apply_func(jnp.minimum, other)
+
+    def maximum(
+        self,
+        other: Any,
+    ) -> Self:
+        if isinstance(other, Vector):
+            return self._apply_operation(other, jnp.maximum)
+        return self.apply_func(jnp.maximum, other)
+
+    def sum(
+        self,
+        axis: Optional[int] = None,
+        keepdims: bool = False,
+    ) -> Self:
+        coefs = {
+            key: jnp.array(jnp.sum(val, axis=axis, keepdims=keepdims))
+            for key, val in self.coefs.items()
+        }
+        return self.__class__(coefs)
+
+    def pack(
+        self,
+    ) -> Dict[str, Any]:
+        return {key: np.asarray(arr) for key, arr in self.coefs.items()}
+
+    @classmethod
+    def unpack(
+        cls,
+        data: Dict[str, Any],
+    ) -> Self:
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        coefs = {k: jax.device_put(arr, device) for k, arr in data.items()}
+        return cls(coefs)
diff --git a/declearn/model/haiku/utils/__init__.py b/declearn/model/haiku/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce83acaf4aaaed40a31a1713fc967d190abddc2
--- /dev/null
+++ b/declearn/model/haiku/utils/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for jax backend support code.
+
+GPU/CPU backing device management utils:
+* select_device:
+    Select a backing device to use based on inputs and availability.
+"""
+
+from ._gpu import select_device
diff --git a/declearn/model/haiku/utils/_gpu.py b/declearn/model/haiku/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c79288e8d234c480e2f839ab026f7181a572aac
--- /dev/null
+++ b/declearn/model/haiku/utils/_gpu.py
@@ -0,0 +1,76 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for GPU support and device management in jax."""
+
+import warnings
+from typing import Optional
+
+import jax
+import jaxlib.xla_extension as xe
+
+__all__ = ["select_device"]
+
+
+def select_device(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> xe.Device:  # pylint: disable=c-extension-no-member
+    """Select a backing device to use based on inputs and availability.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to select a GPU device rather than the CPU one.
+    idx: int or None, default=None
+        Optional pre-selected device index. Only used when `gpu=True`.
+        If `idx is None` or exceeds the number of available GPU devices,
+        use the first available one.
+
+    Warns
+    -----
+    UserWarning:
+        If `gpu=True` but no GPU is available.
+        If `idx` exceeds the number of available GPU devices.
+
+    Returns
+    -------
+    device: jaxlib.xla_extension.Device
+        Selected device.
+    """
+    idx = 0 if idx is None else idx
+    # List available CPU or GPU devices.
+    device_type = "gpu" if gpu else "cpu"
+    devices = [d for d in jax.devices() if d.platform == device_type]
+    # Case when no GPU is available: warn and use a CPU instead.
+    if gpu and not devices:
+        warnings.warn(
+            "Cannot use a GPU device: either CUDA is unavailable "
+            "or no GPU is visible to jax."
+        )
+        device_type, idx = "cpu", 0
+        devices = jax.devices(device_type)
+    # similar code to tensorflow util; pylint: disable=duplicate-code
+    # Case when the desired device index is invalid: select another one.
+    if idx >= len(devices):
+        warnings.warn(
+            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
+            f"\nUsing {device_type} device n°0 instead."
+        )
+        idx = 0
+    # Return the selected device.
+    return devices[idx]
diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index 1f2618b5ab095c45b3fb2b247d6e0b336a970611..ff0f3cead89b7f6755cb80d87814d23a36b67362 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tensorflow models interfacing tools.
+"""Pytorch models interfacing tools.
 
 This submodule provides with a generic interface to wrap up any PyTorch
 `torch.nn.Module` instance that is to be trained with gradient descent.
diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py
index fdc53dd79d6161b8de6f7d0d0ed004b3954ecea0..e0f37b873848748fc87b22b7f23f62d2cf9361ed 100644
--- a/declearn/test_utils/_vectors.py
+++ b/declearn/test_utils/_vectors.py
@@ -35,7 +35,7 @@ __all__ = [
 ]
 
 
-FrameworkType = Literal["numpy", "tensorflow", "torch"]
+FrameworkType = Literal["numpy", "tensorflow", "torch", "jax"]
 
 
 def list_available_frameworks() -> List[FrameworkType]:
@@ -81,6 +81,9 @@ class GradientsTestCase:
         if self.framework == "torch":
             module = importlib.import_module("declearn.model.torch")
             return module.TorchVector
+        if self.framework == "jax":
+            module = importlib.import_module("declearn.model.haiku")
+            return module.JaxNumpyVector
         raise ValueError(f"Invalid framework '{self.framework}'")
 
     def convert(self, array: np.ndarray) -> ArrayLike:
@@ -94,12 +97,17 @@ class GradientsTestCase:
         if self.framework == "torch":
             torch = importlib.import_module("torch")
             return torch.from_numpy(array)
+        if self.framework == "jax":
+            jnp = importlib.import_module("jax.numpy")
+            return jnp.asarray(array)
         raise ValueError(f"Invalid framework '{self.framework}'")
 
     def to_numpy(self, array: ArrayLike) -> np.ndarray:
         """Convert an input framework-based structure to a numpy array."""
         if isinstance(array, np.ndarray):
             return array
+        if self.framework == "jax":
+            return np.asarray(array)
         if self.framework == "tensorflow":  # add support for IndexedSlices
             tensorflow = importlib.import_module("tensorflow")
             if isinstance(array, tensorflow.IndexedSlices):
diff --git a/pyproject.toml b/pyproject.toml
index 70ca5ffd51716408a94016d3818cbb7610ae0c45..e486430064c298cb3162945bd694a790e711753e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,10 +46,12 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-# all non-docs, non-tests extra dependencies
-all = [
+all = [  # all non-tests extra dependencies
+    "dm-haiku == 0.0.9",
+    "jax == 0.4.4",
     "functorch",
     "grpcio >= 1.45",
+    "jax[cpu] == 0.4.4",
     "opacus ~= 1.1",
     "protobuf >= 3.19",
     "tensorflow ~= 2.5",
@@ -64,6 +66,10 @@ grpc = [
     "grpcio >= 1.45",
     "protobuf >= 3.19",
 ]
+haiku = [
+    "jax == 0.4.4",
+    "dm-haiku == 0.0.9",
+]
 tensorflow = [
     "tensorflow ~= 2.5",
 ]
diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index e065e364f4618f0b20ceaf000c3fb8ca759460aa..66f41ecd47605246e6a477980de65ba0a319399e 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -62,24 +62,44 @@ from declearn.test_utils import FrameworkType
 from declearn.utils import run_as_processes
 from declearn.utils import set_device_policy
 
-# pylint: disable=ungrouped-imports; optional frameworks' dependencies
+# optional frameworks' dependencies pylint: disable=ungrouped-imports
+# pylint: disable=duplicate-code
+# tensorflow imports
 try:
     import tensorflow as tf  # type: ignore
-    from declearn.model.tensorflow import TensorflowModel
 except ModuleNotFoundError:
     pass
+else:
+    from declearn.model.tensorflow import TensorflowModel
+# torch imports
 try:
     import torch
+except ModuleNotFoundError:
+    pass
+else:
     from declearn.model.torch import TorchModel
+# pylint: enable=duplicate-code
+# haiku and jax imports
+try:
+    import haiku as hk
+    import jax
 except ModuleNotFoundError:
     pass
+else:
+    from declearn.model.haiku import HaikuModel
+
+    def haiku_model_fn(inputs: jax.Array) -> jax.Array:
+        """Simple linear model implemented with Haiku."""
+        return hk.Linear(1)(inputs)
+
+    def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
+        """Sample-wise squared error loss function."""
+        return (y_pred - y_true) ** 2
 
 
 SEED = 0
 R2_THRESHOLD = 0.999
 
-# pylint: disable=too-many-function-args
-
 
 def get_model(framework: FrameworkType) -> Model:
     """Set up a simple toy regression model."""
@@ -101,6 +121,8 @@ def get_model(framework: FrameworkType) -> Model:
             torch.nn.Flatten(0),
         )
         model = TorchModel(torchmod, loss=torch.nn.MSELoss())
+    elif framework == "jax":
+        model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
     else:
         raise ValueError("unrecognised framework")
     return model
@@ -245,7 +267,9 @@ def _server_routine(
     # pylint: disable=too-many-arguments
     # Set up the FederatedServer.
     model = get_model(framework)
-    netwk = NetworkServerConfig("websockets", "127.0.0.1", 8765)
+    netwk = NetworkServerConfig.from_params(
+        protocol="websockets", host="127.0.0.1", port=8765
+    )
     optim = FLOptimConfig.from_params(
         aggregator="averaging",
         client_opt={
@@ -254,7 +278,6 @@ def _server_routine(
         },
         server_opt=1.0,
     )
-
     server = FederatedServer(
         model,
         netwk,
@@ -281,8 +304,9 @@ def _client_routine(
     name: str = "client",
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
-    # Run the declearn FL client routine.
-    netwk = NetworkClientConfig("websockets", "ws://localhost:8765", name)
+    netwk = NetworkClientConfig.from_params(
+        protocol="websockets", server_uri="ws://localhost:8765", name=name
+    )
     client = FederatedClient(netwk, train, valid)
     client.run()
 
diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b8a2c6bf0dcdfca20dbaee05ad63c0379222b5
--- /dev/null
+++ b/test/model/test_haiku.py
@@ -0,0 +1,295 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for HaikuModel."""
+
+import sys
+from typing import Any, Callable, Dict, List, Literal, Union
+
+import numpy as np
+import pytest
+
+try:
+    import haiku as hk
+    import jax
+    import jax.numpy as jnp
+    from jax.config import config as jaxconfig
+except ModuleNotFoundError:
+    pytest.skip("jax and/or haiku are unavailable", allow_module_level=True)
+
+from declearn.model.haiku import HaikuModel, JaxNumpyVector
+from declearn.typing import Batch
+from declearn.utils import set_device_policy
+
+# dirty trick to import from `model_testing.py`;
+# pylint: disable=wrong-import-order, wrong-import-position
+sys.path.append(".")
+from model_testing import ModelTestCase, ModelTestSuite
+
+# Overriding float32 default in jax
+jaxconfig.update("jax_enable_x64", True)
+
+
+def cnn_fn(inputs: jax.Array) -> jax.Array:
+    """Simple CNN in a purely functional form"""
+    stack = [
+        hk.Conv2D(output_channels=32, kernel_shape=(7, 7), padding="SAME"),
+        jax.nn.relu,
+        hk.MaxPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
+        hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding="SAME"),
+        jax.nn.relu,
+        hk.AvgPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
+        hk.Reshape((16,)),
+        hk.Linear(1),
+    ]
+    model = hk.Sequential(stack)  # type: ignore
+    return model(inputs)
+
+
+def mlp_fn(inputs: jax.Array) -> jax.Array:
+    """Simple MLP in a purely functional form"""
+    model = hk.nets.MLP([32, 16, 1])
+    return model(inputs)
+
+
+def rnn_fn(inputs: jax.Array) -> jax.Array:
+    """Simple RNN in a purely functional form"""
+    inputs = inputs[None, :] if len(inputs.shape) == 1 else inputs
+    core = hk.DeepRNN(
+        [
+            hk.Embed(100, 32),
+            hk.LSTM(32),
+            jax.nn.tanh,
+        ]
+    )
+    batch_size = inputs.shape[0]
+    initial_state = core.initial_state(batch_size)
+    logits, _ = hk.dynamic_unroll(
+        core, inputs, initial_state, time_major=False
+    )
+    return hk.Linear(1)(logits)[:, -1, :]
+
+
+def loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
+    """Per-sample binary cross entropy"""
+    y_pred = jax.nn.sigmoid(y_pred)
+    y_pred = jnp.squeeze(y_pred)
+    log_p, log_not_p = jnp.log(y_pred), jnp.log(1.0 - y_pred)
+    return -y_true * log_p - (1.0 - y_true) * log_not_p
+
+
+class HaikuTestCase(ModelTestCase):
+    """Tensorflow Keras test-case-provider fixture.
+
+    Implemented architectures are:
+    * "MLP":
+        - input: 64-dimensional features vectors
+        - stack: 32-neurons fully-connected layer with ReLU
+                 16-neurons fully-connected layer with ReLU
+                 1 output neuron with sigmoid activation
+    * "RNN":
+        - input: 128-tokens-sequence in a 100-tokens-vocabulary
+        - stack: 32-dimensional embedding matrix
+                 16-neurons LSTM layer with tanh activation
+                 1 output neuron with sigmoid activation
+    * "CNN":
+        - input: 64x64 image with 3 channels (normalized values)
+        - stack: 32 7x7 conv. filters, then 8x8 max pooling
+                 16 5x5 conv. filters, then 8x8 avg pooling
+                 1 output neuron with sigmoid activation
+    """
+
+    vector_cls = JaxNumpyVector
+    tensor_cls = jax.Array
+
+    def __init__(
+        self,
+        kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        device: Literal["cpu", "gpu"],
+    ) -> None:
+        """Specify the desired model architecture."""
+        if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
+            raise ValueError(f"Invalid test architecture: '{kind}'.")
+        if device not in ("cpu", "gpu"):
+            raise ValueError(f"Invalid device choice for test: '{device}'.")
+        self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "gpu"), idx=0)
+
+    @staticmethod
+    def to_numpy(
+        tensor: Any,
+    ) -> np.ndarray:
+        """Convert an input jax jax.Array to a numpy jax.Array."""
+        assert isinstance(tensor, jax.Array)
+        return np.asarray(tensor)
+
+    @property
+    def dataset(
+        self,
+    ) -> List[Batch]:
+        """Suited toy binary-classification dataset."""
+        rng = np.random.default_rng(seed=0)
+        if self.kind.startswith("MLP"):
+            inputs = rng.normal(size=(2, 32, 64)).astype("float32")
+        elif self.kind == "RNN":
+            inputs = rng.choice(100, size=(2, 32, 128))
+        elif self.kind == "CNN":
+            inputs = rng.normal(size=(2, 32, 64, 64, 3)).astype("float32")
+        labels = rng.choice(2, size=(2, 32))
+        inputs = jnp.asarray(inputs)  # type: ignore
+        labels = jnp.asarray(labels)  # type: ignore
+        batches = list(zip(inputs, labels, [None, None]))
+        return batches  # type: ignore
+
+    @property
+    def model(self) -> HaikuModel:
+        """Suited toy binary-classification haiku models."""
+        if self.kind == "CNN":
+            shape = [64, 64, 3]
+            model_fn = cnn_fn
+        elif self.kind.startswith("MLP"):
+            shape = [64]
+            model_fn = mlp_fn
+        elif self.kind == "RNN":
+            shape = [128]
+            model_fn = rnn_fn
+        model = HaikuModel(model_fn, loss_fn)
+        model.initialize(
+            {
+                "features_shape": shape,
+                "data_type": "int" if self.kind == "RNN" else "float32",
+            }
+        )
+        if self.kind == "MLP-tune":
+            names = model.get_weight_names()
+            model.set_trainable_weights([names[i] for i in range(3)])
+        return model
+
+    def assert_correct_device(
+        self,
+        vector: JaxNumpyVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        name = f"{self.device}:0"
+        assert all(
+            f"{arr.device().platform}:{arr.device().id}" == name
+            for arr in vector.coefs.values()
+        )
+
+    def get_trainable_criterion(
+        self,
+        c_type: str,
+    ) -> Union[
+        List[str],
+        Dict[str, Dict[str, Any]],
+        Callable[[str, str, jax.Array], bool],
+    ]:
+        "Build different weight freezing criteria"
+        if c_type == "names":
+            names = self.model.get_weight_names()
+            return [names[2], names[3]]
+        if c_type == "pytree":
+            params = getattr(self.model, "_params")
+            return {k: v for i, (k, v) in enumerate(params.items()) if i != 1}
+        if c_type == "predicate":
+            return lambda m, n, p: n != "b"
+        raise KeyError(f"Invalid 'c_type' parameter: {c_type}.")
+
+
+@pytest.fixture(name="test_case")
+def fixture_test_case(
+    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    device: Literal["cpu", "gpu"],
+) -> HaikuTestCase:
+    """Fixture to access a TensorflowTestCase."""
+    return HaikuTestCase(kind, device)
+
+
+DEVICES = ["cpu"]
+if any(d.platform == "gpu" for d in jax.devices()):
+    DEVICES.append("gpu")
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
+class TestHaikuModel(ModelTestSuite):
+    """Unit tests for declearn.model.tensorflow.TensorflowModel."""
+
+    @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
+    def test_serialization(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        super().test_serialization(test_case)
+
+    @pytest.mark.parametrize(
+        "criterion_type", ["names", "pytree", "predicate"]
+    )
+    def test_get_frozen_weights(
+        self,
+        test_case: HaikuTestCase,
+        criterion_type: str,
+    ) -> None:
+        """Check that `get_weights` behaves properly with frozen weights."""
+        model = test_case.model  # type: HaikuModel
+        criterion = test_case.get_trainable_criterion(criterion_type)
+        model.set_trainable_weights(criterion)  # freeze some weights
+        w_all = model.get_weights()
+        w_trn = model.get_weights(trainable=True)
+        assert set(w_trn.coefs).issubset(w_all.coefs)  # check on keys
+        n_params = len(model.get_weight_names())
+        n_trainable = len(model.get_weight_names(trainable=True))
+        assert n_trainable < n_params
+        assert len(w_trn.coefs) == n_trainable
+        assert len(w_all.coefs) == n_params
+
+    @pytest.mark.parametrize(
+        "criterion_type", ["names", "pytree", "predicate"]
+    )
+    def test_set_frozen_weights(
+        self,
+        test_case: HaikuTestCase,
+        criterion_type: str,
+    ) -> None:
+        """Check that `set_weights` behaves properly with frozen weights."""
+        # similar code to TorchModel tests; pylint: disable=duplicate-code
+        # Setup a model with some frozen weights, and gather trainable ones.
+        model = test_case.model
+        criterion = test_case.get_trainable_criterion(criterion_type)
+        model.set_trainable_weights(criterion)  # freeze some weights
+        w_trn = model.get_weights(trainable=True)
+        # Test that `set_weights` works if and only if properly parametrized.
+        with pytest.raises(KeyError):
+            model.set_weights(w_trn)
+        with pytest.raises(KeyError):
+            model.set_weights(model.get_weights(), trainable=True)
+        model.set_weights(w_trn, trainable=True)
+
+    def test_proper_model_placement(
+        self,
+        test_case: HaikuTestCase,
+    ) -> None:
+        """Check that at instantiation, model weights are properly placed."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "gpu")
+        assert policy.idx == 0
+        params = jax.tree_util.tree_leaves(getattr(model, "_params"))
+        device = f"{test_case.device}:0"
+        for arr in params:
+            assert f"{arr.device().platform}:{arr.device().id}" == device
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index 2e89fb4c71f1f129a2ed2ab3d7d7eb014640bc14..a6da92e767b7db9760357a00777c46116d13e577 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -17,8 +17,8 @@
 
 """Unit tests for TensorflowModel."""
 
-import warnings
 import sys
+import warnings
 from typing import Any, List, Literal
 
 import numpy as np
@@ -38,7 +38,7 @@ from declearn.utils import set_device_policy
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
 sys.path.append(".")
-from model_testing import ModelTestSuite, ModelTestCase
+from model_testing import ModelTestCase, ModelTestSuite
 
 
 class TensorflowTestCase(ModelTestCase):
@@ -51,7 +51,7 @@ class TensorflowTestCase(ModelTestCase):
                  16-neurons fully-connected layer with ReLU
                  1 output neuron with sigmoid activation
     * "MLP-tune":
-        - same as NLP, but freeze the first layer of the stack
+        - same as MLP, but freeze the first layer of the stack
     * "RNN":
         - input: 128-tokens-sequence in a 100-tokens-vocabulary
         - stack: 32-dimensional embedding matrix