Skip to content

Conversation

@YannCabanes
Copy link
Contributor

Add the JAX backend: https://jax.readthedocs.io/en/latest/
The JAX backend can be used for automatic differentiation: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

@YannCabanes
Copy link
Contributor Author

The following error is obtained:

tslearn/tests/test_metrics.py:15: in <module> backends = [Backend("numpy"), None] tslearn/backend/backend.py:99: in __init__ self.backend = select_backend(data) tslearn/backend/backend.py:75: in select_backend backends_instances = [NumPyBackend(), JAXBackend(), PyTorchBackend()] tslearn/backend/backend.py:13: in __init__ raise ValueError("Could not use JAX backend since JAX is not installed.") E ValueError: Could not use JAX backend since JAX is not installed. 
@YannCabanes
Copy link
Contributor Author

It seems from the previous error message that JAX is not installed during the continuous integration tests, therefore the tests in test_metrics.py are not running with the JAX backend.

On my local computer, these tests are failing with the following error message:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html 
@YannCabanes
Copy link
Contributor Author

A solution would be to create a class named JAXNumPyInterface which would define the operators of mutable objects.
The Python operators are listed here: https://docs.python.org/3/library/operator.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

1 participant