diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py index 6b391857d3d365125a76b8e926064ef3a418d058..a63921e34d690b532435c31ac708318d5be0ab79 100644 --- a/declearn/model/haiku/_vector.py +++ b/declearn/model/haiku/_vector.py @@ -39,6 +39,18 @@ __all__ = [ jax.config.update("jax_enable_x64", True) # enable float64 support +def get_array_device(array: jax.Array) -> jax.Device: + """Return the Device on which the input array is placed.""" + devices = array.devices() + if len(devices) > 1: # pragma: no cover + raise RuntimeError( + f"A jax Array is placed on multiple devices: '{devices}'. " + "This is unsupported by DecLearn as of now. Please report " + "this bug to the development team." + ) + return list(devices)[0] + + @register_vector_type( jax.Array, jaxlib.xla_extension.ArrayImpl, # pylint: disable=c-extension-no-member @@ -113,7 +125,7 @@ class JaxNumpyVector(Vector): # Ensure 'other' JaxNumpyVector shares this vector's device placement. if isinstance(other, JaxNumpyVector): coefs = { - key: jax.device_put(val, self.coefs[key].device()) + key: jax.device_put(val, get_array_device(self.coefs[key])) for key, val in other.coefs.items() } other = JaxNumpyVector(coefs) @@ -123,7 +135,10 @@ class JaxNumpyVector(Vector): valid = isinstance(other, JaxNumpyVector) valid = valid and (self.coefs.keys() == other.coefs.keys()) return valid and all( - jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs + jnp.array_equal( + val, jax.device_put(other.coefs[key], get_array_device(val)) + ) + for key, val in self.coefs.items() ) def sign( diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py index 3d4e44a000ae45bb64693870e3ee9283cbb93863..bb623b2419989387368fff58d1fbd1a4f1b8b44b 100644 --- a/test/model/test_haiku_model.py +++ b/test/model/test_haiku_model.py @@ -184,8 +184,9 @@ class HaikuTestCase(ModelTestCase): """Raise if a vector is backed on the wrong type of device.""" name = f"{self.device}:0" assert all( - f"{arr.device().platform}:{arr.device().id}" == name + f"{device.platform}:{device.id}" == name for arr in vector.coefs.values() + for device in arr.devices() ) def get_trainable_criterion( @@ -296,4 +297,6 @@ class TestHaikuModel(ModelTestSuite): params = jax.tree_util.tree_leaves(getattr(model, "_params")) device = f"{test_case.device}:0" for arr in params: - assert f"{arr.device().platform}:{arr.device().id}" == device + assert len(arr.devices()) == 1 + arr_dev = list(arr.devices())[0] + assert f"{arr_dev.platform}:{arr_dev.id}" == device