Skip to content
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592)
- Fiw linesearch import error on Scipy 1.14 (PR #642, Issue #641)
- Upgrade supported JAX versions from jax<=0.4.24 to jax<=0.4.30 (PR #643)

## 0.9.3
*January 2024*
Expand Down
16 changes: 14 additions & 2 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24
except ImportError:
jax = False
jax_type = float
Expand Down Expand Up @@ -1439,11 +1440,19 @@ def __init__(self):
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]

self.jax_new_version = jax_new_version

def _to_numpy(self, a):
return np.array(a)

def _get_device(self, a):
if self.jax_new_version:
return list(a.devices())[0]
else:
return a.device_buffer.device()

def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
return jax.device_put(a, self._get_device(type_as))

def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
Expand Down Expand Up @@ -1688,7 +1697,10 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

def dtype_device(self, a):
return a.dtype, a.device_buffer.device()
if self.jax_new_version:
return a.dtype, list(a.devices())[0]
else:
return a.dtype, a.device_buffer.device()

def assert_same_dtype_device(self, a, b):
a_dtype, a_device = self.dtype_device(a)
Expand Down
4 changes: 2 additions & 2 deletions requirements_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master
cvxopt
scikit-learn
torch
jax<=0.4.24
jaxlib<=0.4.24
jax
jaxlib
tensorflow
pytest
torch_geometric
Expand Down