
Conversation

@tengyifei (Collaborator)

The main purpose is to replace the clunky manual `XlaComputation` object caching at
https://github.com/AI-Hypercomputer/torchprime/blob/b0bd47e3c732c56e75d8d2b315f05e06d485dd22/torchprime/torch_xla_models/experimental/custom_kernel.py#L16 so that callers can just write `xb.call_jax(some_jax_func)` and avoid repeated tracing there.
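
A minimal sketch of the intended usage, assuming the `xb.call_jax(jax_func, args)` calling convention (positional args passed as a tuple); `jax_add` and the shapes are illustrative, not taken from the linked torchprime file:

```python
import jax.numpy as jnp
import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def jax_add(a, b):
  # Any traceable JAX function works here.
  return jnp.sin(a) + b


dev = xm.xla_device()
x = torch.randn(4, 4, device=dev)
y = torch.randn(4, 4, device=dev)

# Instead of manually tracing to an XlaComputation and caching the result,
# call the JAX function directly; repeated calls with the same input shapes
# and dtypes should hit the internal cache and skip re-tracing.
out = xb.call_jax(jax_add, (x, y))
```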

We can't reuse the tracing cache in `jax.jit` because we jit a wrapper rather than `jax_func` itself. `as_serialized_hlo_module_proto` also has overhead of its own, so it is worth avoiding repeated calls to it.

We also improve `xb.call_jax` to support non-tensor arguments. These arguments are passed from `xb.call_jax` to the JAX function unchanged; they are treated as "static arguments" and baked into the HLO.

Because they are static, the JAX function is re-traced whenever their values change.
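
A hedged sketch of passing a non-tensor static argument, again assuming the `xb.call_jax(jax_func, args)` form; `scaled_sum` and `scale` are illustrative names:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def scaled_sum(a, b, scale):
  # `scale` arrives unchanged as a plain Python float and is baked into the HLO.
  return (a + b) * scale


dev = xm.xla_device()
x = torch.randn(8, device=dev)
y = torch.randn(8, device=dev)

out1 = xb.call_jax(scaled_sum, (x, y, 0.5))  # traces once for scale=0.5
out2 = xb.call_jax(scaled_sum, (x, y, 0.5))  # same static value: cache hit
out3 = xb.call_jax(scaled_sum, (x, y, 2.0))  # new static value: re-trace
```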

Fixes #8795.

@tengyifei tengyifei marked this pull request as ready for review March 24, 2025 20:28
@tengyifei tengyifei requested review from bhavya01, qihqi and zpcore March 24, 2025 20:31
@tengyifei tengyifei merged commit a3ef52e into master Mar 24, 2025
23 checks passed
