From f7a13dff9ce3de674ac263628ae8dc98153b2335 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 11 May 2023 15:58:17 +0200
Subject: [PATCH] Fix Jax dependency specification.

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 90873cd2..11f8094e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,7 +50,7 @@ all = [  # all non-tests extra dependencies
     "dm-haiku == 0.0.9",
     "functorch",
     "grpcio >= 1.45",
-    "jax[cpu] == 0.4",
+    "jax[cpu] >= 0.4, < 0.5",
     "opacus ~= 1.1",
     "protobuf >= 3.19",
     "tensorflow ~= 2.5",
@@ -67,7 +67,7 @@ grpc = [
 ]
 haiku = [
     "dm-haiku == 0.0.9",
-    "jax[cpu] == 0.4",  # NOTE: GPU support must be manually installed
+    "jax[cpu] >= 0.4, < 0.5",  # NOTE: GPU support must be manually installed
 ]
 tensorflow = [
     "tensorflow ~= 2.5",
-- 
GitLab