diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 5442e59abf8d1bf46258fc4226c40941b15b78c7..69f3e34f7c21483901a0acdc8c4a822331126ae4 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -55,7 +55,7 @@ __all__ = [
 
 # alias for unpacked Batch structures, converted to jax arrays
 # FUTURE: add support for lists of inputs
-JaxBatch = Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]
+JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]]
 
 
 @register_type(name="HaikuModel", group="Model")
@@ -334,7 +334,7 @@ class HaikuModel(Model):
             must be a function taking in the name of the module (e.g.
             layer name), the element name (e.g. parameter name) and the
             corresponding data and returning a boolean. See
-            [the haiku doc](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.partition) # noqa
+            [the haiku doc](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.partition)
             for details. If a list of integers, should represent the index of
             trainable  parameters in the parameter tree leaves. If a dict, should
             be formatted as a pytree.
@@ -410,7 +410,7 @@ class HaikuModel(Model):
         # FUTURE: add support for lists of inputs
         inputs, y_true, s_wght = batch
         params = hk.data_structures.merge(train_params, fixed_params)
-        y_pred = self._transformed_model.apply(params, rng, inputs)
+        y_pred = self._transformed_model.apply(params, rng, *inputs)
         s_loss = self._loss_fn(y_pred, y_true)  # type: ignore
         if s_wght is not None:
             s_loss = s_loss * s_wght
@@ -480,7 +480,10 @@ class HaikuModel(Model):
     ) -> Callable[
         [hk.Params, hk.Params, jax.Array, JaxBatch, float], jax.Array
     ]:
-        """Lazy-built jax function to compute clipped sample-wise gradients."""
+        """Lazy-built jax function to compute clipped sample-wise gradients.
+
+        Note : The vmap in_axis parameters work thank to the jax feature of
+        applying optional parameters to pytrees."""
 
         def clipped_grad_fn(
             train_params: hk.Params,
@@ -514,13 +517,6 @@ class HaikuModel(Model):
         """Unpack and enforce jnp.array conversion to an input data batch."""
 
         def convert(data: Any) -> Optional[jax.Array]:
-            if isinstance(data, (list, tuple)):
-                if len(data) == 1:
-                    data = data[0]
-                else:
-                    raise TypeError(
-                        "HaikuModel does not support multi-arrays inputs."
-                    )
             if (data is None) or isinstance(data, jax.Array):
                 return data
             if isinstance(data, np.ndarray):
@@ -529,7 +525,9 @@ class HaikuModel(Model):
 
         # Convert batched data to jax Arrays.
         inputs, y_true, s_wght = batch
-        output = [convert(inputs), convert(y_true), convert(s_wght)]
+        if not isinstance(inputs, (tuple, list)):
+            inputs = [inputs]
+        output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)]
         return output  # type: ignore
 
     def apply_updates(  # type: ignore  # Vector subtype specification
@@ -556,7 +554,7 @@ class HaikuModel(Model):
             )
         params = self.get_named_weights()
         y_pred = np.asarray(
-            self._transformed_model.apply(params, next(self._rng_gen), inputs)
+            self._transformed_model.apply(params, next(self._rng_gen), *inputs)
         )
         y_true = np.asarray(y_true)  # type: ignore
         if isinstance(s_wght, jax.Array):