Skip to content

Conversation

@crusaderky
Copy link
Contributor

@crusaderky crusaderky commented May 28, 2025

Related:

Tested locally vs. jax==0.4.31.

Benchmark

>>> import importlib >>> from array_api_compat import array_namespace >>> for xp_name in ("numpy", "cupy", "dask.array", "torch", "jax.numpy", "sparse", "ndonnx"): ... print(xp_name) ... xp = importlib.import_module(xp_name) ... a = xp.asarray(1) ... %timeit array_namespace(a)

Before

numpy 1.31 μs ± 19.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) cupy 1.36 μs ± 5.16 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) dask.array 1.77 μs ± 33.6 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) torch 1.52 μs ± 4.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) jax.numpy 1.4 μs ± 19.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) sparse 1.75 μs ± 9.39 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) ndonnx 1.99 μs ± 10.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) 

After

numpy 673 ns ± 11.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) cupy 494 ns ± 8.45 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) dask.array 498 ns ± 3.87 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) torch 510 ns ± 5.91 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) jax.numpy 521 ns ± 3.06 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) sparse 766 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) ndonnx 727 ns ± 9.62 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) 

Note that the extra slowness of numpy is due to jax-ml/jax#20620

if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
pytest.skip("Unsupported API version")
if (library == "sparse" and api_version in ("2023.12", "2024.12")) or (
library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this test requires jax>=0.6.1 to pass with api_version=2024.12. I didn't see much of a point adding special handling for older versions.

@ev-br
Copy link
Member

ev-br commented Jun 2, 2025

Could you please resolve the conflict

@crusaderky
Copy link
Contributor Author

@ev-br fixed

@ev-br
Copy link
Member

ev-br commented Jun 2, 2025

Grumble grumble:

========================================================================================= short test summary info ========================================================================================= FAILED tests/test_array_namespace.py::test_array_namespace[jax.numpy-2024.12-False] - ValueError: api_version='2024.12' is not available; available versions are: ['2023.12'] FAILED tests/test_array_namespace.py::test_array_namespace[jax.numpy-2024.12-None] - ValueError: api_version='2024.12' is not available; available versions are: ['2023.12'] ============================================================== 2 failed, 410 passed, 105 skipped, 6 xfailed, 2 xpassed, 46 warnings in 7.66s ============================================================== 

EDIT: sorry, no; my local Jax copy was too old. Nevermind.

@ev-br ev-br merged commit 6ae28ee into data-apis:main Jun 2, 2025
23 checks passed
@ev-br
Copy link
Member

ev-br commented Jun 2, 2025

Thanks @crusaderky , merged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants