JAX prefers immutable objects but neither Python nor JAX provide an immutable dictionary. 😢
This repository defines a light-weight immutable map (lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️
pip install xmmutablemapusing uv
uv add xmmutablemapfrom source, using pip
pip install git+https://github.com/GalacticDynamics/xmmutablemap.gitbuilding from source
cd /path/to/parent git clone https://github.com/GalacticDynamics/xmmutablemap.git cd xmmutablemap pip install -e . # editable modexmutablemap provides the class ImmutableMap, which is a full implementation of Python's Mapping ABC. If you've used a dict then you already know how to use ImmutableMap! The things ImmutableMap adds is 1) immutability (and related benefits like hashability) and 2) compatibility with JAX.
from xmmutablemap import ImmutableMap print(ImmutableMap(a=1, b=2, c=3)) # ImmutableMap({'a': 1, 'b': 2, 'c': 3}) print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"})) # ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})One of the key benefits of ImmutableMap is its compatibility with JAX. Since it's immutable and hashable, it can be used in places where JAX would normally complain about mutable objects like regular dictionaries.
Here's an example showing how ImmutableMap can be used as a default value in a dataclass, which is particularly useful with JAX:
import functools import jax import jax.numpy as jnp from dataclasses import dataclass from xmmutablemap import ImmutableMap @functools.partial( jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"] ) @dataclass(frozen=True) class Config: """Configuration with immutable default parameters.""" # This works! ImmutableMap is immutable and hashable params: ImmutableMap[str, float] = ImmutableMap( learning_rate=0.001, momentum=0.9, weight_decay=1e-4 ) batch_size: int = 32 # JAX can safely transform functions using this dataclass @jax.jit def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray: """Example training step that uses config parameters.""" lr = config.params["learning_rate"] return data * lr # This works perfectly config = Config() data = jnp.array([1.0, 2.0, 3.0]) result = train_step(config, data) print(f"Result: {result}") # Result: [0.001 0.002 0.003]- Immutability: Once created,
ImmutableMapcannot be modified, preventing accidental mutations that could break JAX's functional programming model - Hashability: JAX can safely cache and memoize functions that use
ImmutableMapinstances - PyTree Support:
ImmutableMapis registered as a JAX PyTree, so it works seamlessly with JAX transformations likejit,grad,vmap, etc. - Safe Defaults: Can be used as default values in dataclasses without the typical pitfalls of mutable defaults
We welcome contributions!