You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fix libtpu version for torch and do not pre-install tensorflow-tpu on TPU. (#1499)
We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23, and that libtpu version causes pjrt api errors for torch 2.8.0: ``` pjrt_c_api_helpers.cc:258] Unexpected error status Unexpected PJRT_Plugin_Attributes_Args size: expe cted 32, got 24. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.75. ``` * https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111 * https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26 Of particular note, we no longer pre-install `tensorflow-tpu` as the newer libtpu causes issues finding the TPUs ``` external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:78] No TPU platform found. Platform manager status: OK ``` We also update how we install Python packages via `uv` for consistency and reproducibility. From a `requirements.in` file, we first generate a consistent dependency closure via `uv pip compile`, and then `uv pip install` the packages from the generated `requirements.txt`.
0 commit comments