Skip to content

Commit 39bb876

Browse files
committed
fix jax versions
1 parent 082be03 commit 39bb876

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ jax_core_deps = [
103103
"protobuf==4.25.5",
104104
]
105105
jax_cpu = [
106-
"jax==0.4.26",
107-
"jaxlib==0.4.26",
106+
"jax==0.4.28",
107+
"jaxlib==0.4.28",
108108
"algorithmic_efficiency[jax_core_deps]",
109109
]
110110
jax_gpu = [
111-
"jax==0.4.26",
112-
"jaxlib==0.4.26",
111+
"jax==0.4.28",
112+
"jaxlib==0.4.28",
113113
"jax-cuda12-plugin[with_cuda]==0.4.28",
114114
"jax-cuda12-pjrt==0.4.28",
115115
"algorithmic_efficiency[jax_core_deps]",

0 commit comments

Comments
 (0)