From d1c3f8c6f6780cc6eec8035a515842c0b284876c Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Wed, 19 Jun 2024 10:24:17 +0200 Subject: [PATCH] Fix support for more recent JAX versions. Recent JAX versions have deprecated, then dropped, the `device()` method of jax (numpy) arrays, in favor of the already-existing, distinct-return-type `devices()` one. This commit therefore replaces calls to `device` with operations that include calling `devices` and extracting its content. This is not very elegant, but seems to be the proper way to do things based on the current JAX documentation. --- declearn/model/haiku/_vector.py | 19 +++++++++++++++++-- test/model/test_haiku_model.py | 7 +++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py index 6b391857..a63921e3 100644 --- a/declearn/model/haiku/_vector.py +++ b/declearn/model/haiku/_vector.py @@ -39,6 +39,18 @@ __all__ = [ jax.config.update("jax_enable_x64", True) # enable float64 support +def get_array_device(array: jax.Array) -> jax.Device: + """Return the Device on which the input array is placed.""" + devices = array.devices() + if len(devices) > 1: # pragma: no cover + raise RuntimeError( + f"A jax Array is placed on multiple devices: '{devices}'. " + "This is unsupported by DecLearn as of now. Please report " + "this bug to the development team." + ) + return list(devices)[0] + + @register_vector_type( jax.Array, jaxlib.xla_extension.ArrayImpl, # pylint: disable=c-extension-no-member @@ -113,7 +125,7 @@ class JaxNumpyVector(Vector): # Ensure 'other' JaxNumpyVector shares this vector's device placement. if isinstance(other, JaxNumpyVector): coefs = { - key: jax.device_put(val, self.coefs[key].device()) + key: jax.device_put(val, get_array_device(self.coefs[key])) for key, val in other.coefs.items() } other = JaxNumpyVector(coefs) @@ -123,7 +135,10 @@ class JaxNumpyVector(Vector): valid = isinstance(other, JaxNumpyVector) valid = valid and (self.coefs.keys() == other.coefs.keys()) return valid and all( - jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs + jnp.array_equal( + val, jax.device_put(other.coefs[key], get_array_device(val)) + ) + for key, val in self.coefs.items() ) def sign( diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py index 3d4e44a0..bb623b24 100644 --- a/test/model/test_haiku_model.py +++ b/test/model/test_haiku_model.py @@ -184,8 +184,9 @@ class HaikuTestCase(ModelTestCase): """Raise if a vector is backed on the wrong type of device.""" name = f"{self.device}:0" assert all( - f"{arr.device().platform}:{arr.device().id}" == name + f"{device.platform}:{device.id}" == name for arr in vector.coefs.values() + for device in arr.devices() ) def get_trainable_criterion( @@ -296,4 +297,6 @@ class TestHaikuModel(ModelTestSuite): params = jax.tree_util.tree_leaves(getattr(model, "_params")) device = f"{test_case.device}:0" for arr in params: - assert f"{arr.device().platform}:{arr.device().id}" == device + assert len(arr.devices()) == 1 + arr_dev = list(arr.devices())[0] + assert f"{arr_dev.platform}:{arr_dev.id}" == device -- GitLab