From f7a13dff9ce3de674ac263628ae8dc98153b2335 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 11 May 2023 15:58:17 +0200 Subject: [PATCH] Fix Jax dependency specification. --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90873cd2..11f8094e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ all = [ # all non-tests extra dependencies "dm-haiku == 0.0.9", "functorch", "grpcio >= 1.45", - "jax[cpu] == 0.4", + "jax[cpu] >= 0.4, < 0.5", "opacus ~= 1.1", "protobuf >= 3.19", "tensorflow ~= 2.5", @@ -67,7 +67,7 @@ grpc = [ ] haiku = [ "dm-haiku == 0.0.9", - "jax[cpu] == 0.4", # NOTE: GPU support must be manually installed + "jax[cpu] >= 0.4, < 0.5", # NOTE: GPU support must be manually installed ] tensorflow = [ "tensorflow ~= 2.5", -- GitLab