diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py
index 6b391857d3d365125a76b8e926064ef3a418d058..a63921e34d690b532435c31ac708318d5be0ab79 100644
--- a/declearn/model/haiku/_vector.py
+++ b/declearn/model/haiku/_vector.py
@@ -39,6 +39,18 @@ __all__ = [
 jax.config.update("jax_enable_x64", True)  # enable float64 support
 
 
+def get_array_device(array: jax.Array) -> jax.Device:
+    """Return the Device on which the input array is placed."""
+    devices = array.devices()
+    if len(devices) > 1:  # pragma: no cover
+        raise RuntimeError(
+            f"A jax Array is placed on multiple devices: '{devices}'. "
+            "This is unsupported by DecLearn as of now. Please report "
+            "this bug to the development team."
+        )
+    return list(devices)[0]
+
+
 @register_vector_type(
     jax.Array,
     jaxlib.xla_extension.ArrayImpl,  # pylint: disable=c-extension-no-member
@@ -113,7 +125,7 @@ class JaxNumpyVector(Vector):
         # Ensure 'other' JaxNumpyVector shares this vector's device placement.
         if isinstance(other, JaxNumpyVector):
             coefs = {
-                key: jax.device_put(val, self.coefs[key].device())
+                key: jax.device_put(val, get_array_device(self.coefs[key]))
                 for key, val in other.coefs.items()
             }
             other = JaxNumpyVector(coefs)
@@ -123,7 +135,10 @@ class JaxNumpyVector(Vector):
         valid = isinstance(other, JaxNumpyVector)
         valid = valid and (self.coefs.keys() == other.coefs.keys())
         return valid and all(
-            jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
+            jnp.array_equal(
+                val, jax.device_put(other.coefs[key], get_array_device(val))
+            )
+            for key, val in self.coefs.items()
         )
 
     def sign(
diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py
index 3d4e44a000ae45bb64693870e3ee9283cbb93863..bb623b2419989387368fff58d1fbd1a4f1b8b44b 100644
--- a/test/model/test_haiku_model.py
+++ b/test/model/test_haiku_model.py
@@ -184,8 +184,9 @@ class HaikuTestCase(ModelTestCase):
         """Raise if a vector is backed on the wrong type of device."""
         name = f"{self.device}:0"
         assert all(
-            f"{arr.device().platform}:{arr.device().id}" == name
+            f"{device.platform}:{device.id}" == name
             for arr in vector.coefs.values()
+            for device in arr.devices()
         )
 
     def get_trainable_criterion(
@@ -296,4 +297,6 @@ class TestHaikuModel(ModelTestSuite):
         params = jax.tree_util.tree_leaves(getattr(model, "_params"))
         device = f"{test_case.device}:0"
         for arr in params:
-            assert f"{arr.device().platform}:{arr.device().id}" == device
+            assert len(arr.devices()) == 1
+            arr_dev = list(arr.devices())[0]
+            assert f"{arr_dev.platform}:{arr_dev.id}" == device