From dd2eef5029ab0a75785372326d5e3fe472c72799 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 10 May 2023 14:49:12 +0200
Subject: [PATCH] Fix jax dependency specification.

---
 pyproject.toml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e4864300..3ce5cc8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,6 @@ dependencies = [
 [project.optional-dependencies]
 all = [  # all non-tests extra dependencies
     "dm-haiku == 0.0.9",
-    "jax == 0.4.4",
     "functorch",
     "grpcio >= 1.45",
     "jax[cpu] == 0.4.4",
@@ -67,8 +66,8 @@ grpc = [
     "protobuf >= 3.19",
 ]
 haiku = [
-    "jax == 0.4.4",
     "dm-haiku == 0.0.9",
+    "jax[cpu] == 0.4.4",
 ]
 tensorflow = [
     "tensorflow ~= 2.5",
-- 
GitLab