diff --git a/pyproject.toml b/pyproject.toml index 90873cd295b2a3c92469e492179da249a9981669..11f8094e2f7a2c26d93e775db9f5daa590b9a129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ all = [ # all non-tests extra dependencies "dm-haiku == 0.0.9", "functorch", "grpcio >= 1.45", - "jax[cpu] == 0.4", + "jax[cpu] >= 0.4, < 0.5", "opacus ~= 1.1", "protobuf >= 3.19", "tensorflow ~= 2.5", @@ -67,7 +67,7 @@ grpc = [ ] haiku = [ "dm-haiku == 0.0.9", - "jax[cpu] == 0.4", # NOTE: GPU support must be manually installed + "jax[cpu] >= 0.4, < 0.5", # NOTE: GPU support must be manually installed ] tensorflow = [ "tensorflow ~= 2.5",