diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb5dea327a6b0eff272894be80a3d082c075ac83
--- /dev/null
+++ b/test/model/test_haiku.py
@@ -0,0 +1,197 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for TensorflowModel."""
+
+import sys
+from functools import partial
+from typing import Any, List, Literal
+
+import numpy as np
+import pytest
+
+try:
+    import chex
+    import haiku as hk
+    import jax
+    import jax.numpy as jnp
+    from jax.config import config as jaxconfig
+    from jaxtyping import Array
+except ModuleNotFoundError:
+    pytest.skip("jax and/or haiku are unavailable", allow_module_level=True)
+
+from declearn.model.haiku import HaikuModel, JaxNumpyVector
+from declearn.typing import Batch
+from declearn.utils import set_device_policy
+
+# dirty trick to import from `model_testing.py`;
+# pylint: disable=wrong-import-order, wrong-import-position
+sys.path.append(".")
+from model_testing import ModelTestCase, ModelTestSuite
+
+# Overriding float32 default in jax
+jaxconfig.update("jax_enable_x64", True)
+
+# TODO : Implement LSTM with unroll_net, check https://github.com/deepmind/dm-haiku/blob/main/examples/haiku_lstms.ipynb
+# TODO add chex variants
+
+
+def cnn_fn(x: Array) -> jnp.ndarray:
+    """Simple CNN in a purely functional form"""
+    model = hk.Sequential(
+        [
+            hk.Conv2D(output_channels=32, kernel_shape=(7, 7), padding="SAME"),
+            jax.nn.relu,
+            hk.MaxPool(window_shape=(8, 8), strides=(1, 1), padding="SAME"),
+            hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding="SAME"),
+            jax.nn.relu,
+            hk.AvgPool(window_shape=(8, 8), strides=(1, 1), padding="SAME"),
+            hk.Flatten(),
+            hk.Linear(1),
+        ]
+    )
+    return model(x)
+
+
+def mlp_fn(x: Array) -> jnp.ndarray:
+    """Simple MLP in a purely functional form"""
+    model = hk.nets.MLP([32, 16, 1])
+    return model(x)
+
+
+def loss_fn(y_pred: Array, y_true: Array) -> Array:
+    """Per-sample binary cross entropy"""
+    y_pred = jax.nn.sigmoid(y_pred)
+    y_pred = jnp.squeeze(y_pred)
+    log_p, log_not_p = jnp.log(y_pred), jnp.log(1.0 - y_pred)
+    return -y_true * log_p - (1.0 - y_true) * log_not_p
+
+
+class HaikuTestCase(ModelTestCase):
+    """Tensorflow Keras test-case-provider fixture.
+
+    Implemented architectures are:
+    * "MLP":
+        - input: 64-dimensional features vectors
+        - stack: 32-neurons fully-connected layer with ReLU
+                 16-neurons fully-connected layer with ReLU
+                 1 output neuron with sigmoid activation
+    * "CNN":
+        - input: 64x64 image with 3 channels (normalized values)
+        - stack: 32 7x7 conv. filters, then 8x8 max pooling
+                 16 5x5 conv. filters, then 8x8 avg pooling
+                 1 output neuron with sigmoid activation
+    """
+
+    vector_cls = JaxNumpyVector
+    tensor_cls = Array
+
+    def __init__(
+        self,
+        kind: Literal["MLP", "CNN"],
+        device: Literal["cpu", "gpu"],
+    ) -> None:
+        """Specify the desired model architecture."""
+        if kind not in ("MLP", "CNN"):
+            raise ValueError(f"Invalid test architecture: '{kind}'.")
+        if device not in ("cpu", "gpu"):
+            raise ValueError(f"Invalid device choice for test: '{device}'.")
+        self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "gpu"), idx=0)
+
+    @staticmethod
+    def to_numpy(
+        tensor: Any,
+    ) -> np.ndarray:
+        """Convert an input jax array to a numpy array."""
+        assert isinstance(tensor, Array)
+        return np.asarray(tensor)
+
+    @property
+    def dataset(
+        self,
+    ) -> List[Batch]:
+        """Suited toy binary-classification dataset."""
+        rng = np.random.default_rng(seed=0)
+        if self.kind.startswith("MLP"):
+            inputs = rng.normal(size=(2, 32, 64))
+        elif self.kind == "CNN":
+            inputs = rng.normal(size=(2, 32, 64, 64, 3))
+        labels = rng.choice([0, 1], size=(2, 32)).astype(float)
+        inputs = jnp.asarray(inputs)
+        labels = jnp.asarray(labels)
+        batches = list(zip(inputs, labels, [None, None]))
+        return batches
+
+    @property
+    def model(self) -> HaikuModel:
+        """Suited toy binary-classification haiku models."""
+        if self.kind == "CNN":
+            shape = [64, 64, 3]
+            model_fn = cnn_fn
+        elif self.kind == "MLP":
+            shape = [64]
+            model_fn = mlp_fn
+        model = HaikuModel(model_fn, loss_fn)
+        model.initialize({"features_shape": shape, "data_type": "float64"})
+        return model
+
+    def assert_correct_device(
+        self,
+        vector: JaxNumpyVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        name = f"{self.device}:0"
+        assert all(
+            f"{arr.device().platform}:{arr.device().id}" == name
+            for arr in vector.coefs.values()
+        )
+
+
+@pytest.fixture(name="test_case")
+def fixture_test_case(
+    kind: Literal["MLP", "CNN"],
+    device: Literal["cpu", "gpu"],
+) -> HaikuTestCase:
+    """Fixture to access a TensorflowTestCase."""
+    return HaikuTestCase(kind, device)
+
+
+DEVICES = ["cpu"]
+if "gpu" in [d.platform for d in jax.devices()]:
+    DEVICES.append("gpu")
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kind", ["MLP", "CNN"])
+class TestTensorflowModel(ModelTestSuite):
+    """Unit tests for declearn.model.tensorflow.TensorflowModel."""
+
+    def test_proper_model_placement(
+        self,
+        test_case: HaikuTestCase,
+    ) -> None:
+        """Check that at instantiation, model weights are properly placed."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "gpu")
+        assert policy.idx == 0
+        params = getattr(model, "_params_leaves")
+        device = f"{test_case.device}:0"
+        for arr in params:
+            assert f"{arr.device().platform}:{arr.device().id}" == device