Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit d1c3f8c6 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix support for more recent JAX versions.

Recent JAX versions have deprecated, then dropped, the `device()`
method of jax (numpy) arrays, in favor of the already-existing,
distinct-return-type `devices()` one.

This commit therefore replaces calls to `device` with operations
that include calling `devices` and extracting its content. This
is not very elegant, but seems to be the proper way to do things
based on the current JAX documentation.
parent 2dc00b65
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,18 @@ __all__ = [
jax.config.update("jax_enable_x64", True) # enable float64 support
def get_array_device(array: jax.Array) -> jax.Device:
"""Return the Device on which the input array is placed."""
devices = array.devices()
if len(devices) > 1: # pragma: no cover
raise RuntimeError(
f"A jax Array is placed on multiple devices: '{devices}'. "
"This is unsupported by DecLearn as of now. Please report "
"this bug to the development team."
)
return list(devices)[0]
@register_vector_type(
jax.Array,
jaxlib.xla_extension.ArrayImpl, # pylint: disable=c-extension-no-member
......@@ -113,7 +125,7 @@ class JaxNumpyVector(Vector):
# Ensure 'other' JaxNumpyVector shares this vector's device placement.
if isinstance(other, JaxNumpyVector):
coefs = {
key: jax.device_put(val, self.coefs[key].device())
key: jax.device_put(val, get_array_device(self.coefs[key]))
for key, val in other.coefs.items()
}
other = JaxNumpyVector(coefs)
......@@ -123,7 +135,10 @@ class JaxNumpyVector(Vector):
valid = isinstance(other, JaxNumpyVector)
valid = valid and (self.coefs.keys() == other.coefs.keys())
return valid and all(
jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
jnp.array_equal(
val, jax.device_put(other.coefs[key], get_array_device(val))
)
for key, val in self.coefs.items()
)
def sign(
......
......@@ -184,8 +184,9 @@ class HaikuTestCase(ModelTestCase):
"""Raise if a vector is backed on the wrong type of device."""
name = f"{self.device}:0"
assert all(
f"{arr.device().platform}:{arr.device().id}" == name
f"{device.platform}:{device.id}" == name
for arr in vector.coefs.values()
for device in arr.devices()
)
def get_trainable_criterion(
......@@ -296,4 +297,6 @@ class TestHaikuModel(ModelTestSuite):
params = jax.tree_util.tree_leaves(getattr(model, "_params"))
device = f"{test_case.device}:0"
for arr in params:
assert f"{arr.device().platform}:{arr.device().id}" == device
assert len(arr.devices()) == 1
arr_dev = list(arr.devices())[0]
assert f"{arr_dev.platform}:{arr_dev.id}" == device
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment