Fix libtpu version for torch and do not pre-install tensorflow-tpu on TPU. #1499
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23, and that libtpu version causes pjrt api errors for torch 2.8.0:
Of particular note, we no longer pre-install
tensorflow-tpu
as the newer libtpu causes issues finding the TPUsWe also update how we install Python packages via
uv
for consistency and reproducibility. From arequirements.in
file, we first generate a consistent dependency closure viauv pip compile
, and thenuv pip install
the packages from the generatedrequirements.txt
.