> [!WARNING]
> This is an early experimental package. Feedback wanted!
A Python package providing an easy-to-use MPI backend for JAX sharding, built on top of MPIWrapper. No special operations, 100% native JAX.
Install with uv or pip:

```bash
# Using uv (recommended)
uv add git+https://github.com/mpi4jax/mpibackend4jax

# Using pip
pip install git+https://github.com/mpi4jax/mpibackend4jax
```

Simply import the package before using JAX with MPI:
```python
import mpibackend4jax as _mpi4jax  # noqa: F401
import jax

print("Setup initialize", flush=True)
jax.distributed.initialize()

print(f"{jax.process_index()}/{jax.process_count()} :", jax.local_devices())
print(f"{jax.process_index()}/{jax.process_count()} :", jax.devices())

x = jax.numpy.ones(
    (jax.device_count(),),
    device=jax.sharding.NamedSharding(
        jax.sharding.Mesh(jax.devices(), "i"), jax.sharding.PartitionSpec("i")
    ),
)
print(f"{jax.process_index()}/{jax.process_count()} :", x.sum())
```

Run with MPI:
```bash
mpirun -np 2 python examples/example.py
```

When you import `mpibackend4jax`, it automatically:
- Sets `MPITRAMPOLINE_LIB` to point to the built `libmpiwrapper.so`
- Sets `JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi`
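For reference, this is roughly equivalent to setting the environment variables yourself. A minimal sketch, assuming a placeholder path for `libmpiwrapper.so` (the real location depends on where MPIWrapper was built):

```python
import os

# Hypothetical manual setup; mpibackend4jax normally does this on import.
# Replace the placeholder path with the actual location of your built
# libmpiwrapper.so.
os.environ["MPITRAMPOLINE_LIB"] = "/path/to/libmpiwrapper.so"
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"

# Set the variables before importing JAX, mirroring the recommended
# "import mpibackend4jax before JAX" ordering.
import jax  # noqa: E402
```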
Requirements:

- CMake (for building MPIWrapper)
- A working MPI implementation (e.g., OpenMPI, MPICH)
- JAX
Tested on macOS with MPICH.
You can check if MPITrampoline is properly configured:
```python
import mpibackend4jax

if mpibackend4jax.is_configured():
    print("MPITrampoline is properly configured!")
else:
    print("MPITrampoline configuration failed.")
```

Special thanks to @inailuig (Clemens Giuliani) for adding MPI support in XLA, which makes this integration possible.