Documentation: https://array-api-jit.readthedocs.io
Source Code: https://github.com/34j/array-api-jit
A JIT decorator supporting multiple Array API-compatible libraries
Install this via pip (or your favourite package manager):
```shell
pip install array-api-jit
```
Simply decorate your function with `@jit()`:
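```python
from typing import Any

from array_api_compat import array_namespace
from array_api_jit import jit


@jit()
def my_function(x: Any) -> Any:
    # Resolve the array library (NumPy, JAX, PyTorch, ...) from the input itself
    xp = array_namespace(x)
    return xp.sin(x) + xp.cos(x)
```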
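Because the same decorated function can receive NumPy, JAX, or PyTorch arrays, the decorator has to pick a JIT backend per call. The stdlib-only sketch below shows one plausible way to do that; it is not array-api-jit's implementation. It assumes the library can be identified from the top-level module of the first argument's type, and the names `toy_jit` and `library_of` are hypothetical. A module name can differ from the import name (JAX's concrete array type is defined under `jaxlib`, for example), so a real implementation would more likely use helpers such as `array_api_compat.is_jax_array`.

```python
import functools
from typing import Any, Callable

Decorator = Callable[[Callable[..., Any]], Callable[..., Any]]


def library_of(x: Any) -> str:
    # Hypothetical helper: the top-level package that defines type(x),
    # e.g. "numpy" for numpy.ndarray, "builtins" for float.
    return type(x).__module__.split(".")[0]


def toy_jit(decorators: dict[str, Decorator]) -> Decorator:
    """Hypothetical per-library JIT dispatcher (illustration only)."""

    def outer(func: Callable[..., Any]) -> Callable[..., Any]:
        cache: dict[str, Callable[..., Any]] = {}  # library name -> wrapped function

        @functools.wraps(func)
        def inner(x: Any, *args: Any, **kwargs: Any) -> Any:
            lib = library_of(x)
            if lib not in cache:
                # Apply the library's decorator once, on first use;
                # libraries without an entry run the plain Python function.
                decorator = decorators.get(lib)
                cache[lib] = decorator(func) if decorator is not None else func
            return cache[lib](x, *args, **kwargs)

        return inner

    return outer
```

For example, a function decorated with `toy_jit({"builtins": some_decorator})` has `some_decorator` applied once, on the first call with a Python `int` or `float`, and reuses the cached result afterwards; arguments from any other module run the undecorated function.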
You can specify the decorator, arguments, and keyword arguments for each library.
```python
from typing import Any

import numba
from array_api_compat import array_namespace
from array_api_jit import jit


@jit(
    # numba.jit is not applied to NumPy by default because compilation may not succeed
    {"numpy": numba.jit()},
    # jax requires the loop bound `n` to be static (listed in "static_argnames")
    decorator_kwargs={"jax": {"static_argnames": ["n"]}},
    # fail_on_error=False,  # (default) do not raise an error if the decorator fails
    # rerun_on_error=True,  # (NOT the default) re-run the original function if the wrapped function fails
)
def sin_n_times(x: Any, n: int) -> Any:
    xp = array_namespace(x)
    for _ in range(n):
        x = xp.sin(x)
    return x
```
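The commented-out `fail_on_error` and `rerun_on_error` options refer to two different failure points: applying the decorator, and calling the decorated function. Compilers such as `numba.jit` typically compile lazily, so their errors often surface only on the first call. The stdlib-only sketch below models those two options as described in the comments above; it is an assumption about their semantics, not the package's code, and `apply_decorator` is a hypothetical name.

```python
import functools
import warnings
from typing import Any, Callable, Optional


def apply_decorator(
    func: Callable[..., Any],
    decorator: Callable[..., Callable[..., Any]],
    decorator_kwargs: Optional[dict[str, Any]] = None,
    *,
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
) -> Callable[..., Any]:
    """Hypothetical model of the fail_on_error / rerun_on_error options."""
    try:
        # Decoration time, e.g. jax.jit(func, static_argnames=["n"]).
        wrapped = decorator(func, **(decorator_kwargs or {}))
    except Exception:
        if fail_on_error:
            raise
        # Default: warn and keep the undecorated function.
        warnings.warn(f"decorator failed for {func.__name__!r}; using the original function")
        return func

    if not rerun_on_error:
        return wrapped

    @functools.wraps(func)
    def guarded(*args: Any, **kwargs: Any) -> Any:
        try:
            # Call time: lazy compilers may only fail here.
            return wrapped(*args, **kwargs)
        except Exception:
            # Fall back to the plain Python function.
            return func(*args, **kwargs)

    return guarded
```

In this sketch, leaving `rerun_on_error=False` lets a call-time failure propagate to the caller, which is one plausible reason for it not being the default: a silent fallback to uncompiled Python code could otherwise go unnoticed.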
Thanks go to these wonderful people (emoji key):
This project follows the all-contributors specification. Contributions of any kind are welcome!
This package was created with Copier and the browniebroke/pypackage-template project template.