Fix support for more recent JAX versions.
This single-commit PR is a straightforward fix to support recent JAX versions.
It will be backported as a subminor patch to DecLearn 2.5, and perhaps older versions.
Recent JAX versions have deprecated, then dropped, the device()
method of jax (numpy) arrays, in favor of the already-existing,
distinct-return-type devices()
one.
This commit therefore replaces calls to device
with operations
that include calling devices
and extracting its content. This
is not very elegant, but seems to be the proper way to do things
based on the current JAX documentation.