-
- Downloads
Fix support for more recent JAX versions.
Recent JAX versions have deprecated, then dropped, the `device()` method of jax (numpy) arrays, in favor of the already-existing, distinct-return-type `devices()` one. This commit therefore replaces calls to `device` with operations that include calling `devices` and extracting its content. This is not very elegant, but seems to be the proper way to do things based on the current JAX documentation.
Loading
Please register or sign in to comment