@metrizable metrizable commented Sep 26, 2025

We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. The reason is the dependency chain: latest tunix -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23, and libtpu 0.0.23 causes PJRT API errors with torch 2.8.0:

pjrt_c_api_helpers.cc:258] Unexpected error status Unexpected PJRT_Plugin_Attributes_Args size: expected 32, got 24. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.75.

Of particular note, we no longer pre-install tensorflow-tpu, as the newer libtpu causes issues finding the TPUs:

external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:78] No TPU platform found. Platform manager status: OK 
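When debugging a mismatch like the ones above, it can help to print which versions actually ended up in the environment. The following is a minimal sketch, not part of this change; it assumes the packages are installed under the distribution names `jax`, `flax`, `torch`, and `libtpu` (the `libtpu` distribution name in particular is an assumption), and the helper name `installed_versions` is hypothetical.

```python
from importlib import metadata

# Distribution names taken from the description above; treating "libtpu"
# as the installed distribution name is an assumption.
PACKAGES = ["jax", "flax", "torch", "libtpu"]


def installed_versions(names):
    """Return a dict mapping each distribution name to its version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the image before and after a dependency bump would show whether libtpu moved to a version that one of the frameworks was not built against.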

We also update how we install Python packages via uv for consistency and reproducibility. From a requirements.in file, we first generate a consistent dependency closure via uv pip compile, and then uv pip install the packages from the generated requirements.txt.
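The two-step flow described above can be sketched as a small Python driver. This is an illustration under assumptions, not the actual build script: the helper names `uv_commands` and `run_all` are hypothetical, the file names follow the description, and any extra flags the real build passes to uv are not shown.

```python
import subprocess


def uv_commands(req_in="requirements.in", req_txt="requirements.txt"):
    """Build the two uv invocations: resolve the closure, then install from it."""
    # Step 1: resolve requirements.in into a fully pinned requirements.txt.
    compile_cmd = ["uv", "pip", "compile", req_in, "-o", req_txt]
    # Step 2: install exactly what the generated file pins.
    install_cmd = ["uv", "pip", "install", "-r", req_txt]
    return [compile_cmd, install_cmd]


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first failing step.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Dry run by default so the sketch is safe to execute without uv installed.
    run_all(uv_commands())
```

Separating resolution from installation means the generated requirements.txt records the full dependency closure, so a conflicting pin surfaces at compile time rather than partway through an install.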

@calderjo calderjo left a comment

awesome thanks!

@calderjo calderjo merged commit 3e031ba into Kaggle:main Sep 26, 2025
1 of 2 checks passed