Skip to content

Commit 3e031ba

Browse files
authored
Fix libtpu version for torch and do not pre-install tensorflow-tpu on TPU. (#1499)
We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23, and that libtpu version causes pjrt api errors for torch 2.8.0: ``` pjrt_c_api_helpers.cc:258] Unexpected error status Unexpected PJRT_Plugin_Attributes_Args size: expe cted 32, got 24. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.75. ``` * https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111 * https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26 Of particular note, we no longer pre-install `tensorflow-tpu` as the newer libtpu causes issues finding the TPUs ``` external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:78] No TPU platform found. Platform manager status: OK ``` We also update how we install Python packages via `uv` for consistency and reproducibility. From a `requirements.in` file, we first generate a consistent dependency closure via `uv pip compile`, and then `uv pip install` the packages from the generated `requirements.txt`.
1 parent acb8bcc commit 3e031ba

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

tpu/Dockerfile

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,39 @@ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
3434
# Additional useful packages should be added in the requirements.txt
3535
# Bring in the requirements.txt and replace variables in it:
3636
RUN apt-get install -y gettext
37-
ADD tpu/requirements.txt /kaggle_requirements.txt
38-
RUN envsubst < /kaggle_requirements.txt > /requirements.txt
37+
ADD tpu/requirements.in /kaggle_requirements.in
38+
RUN envsubst < /kaggle_requirements.in > /requirements.in
3939

4040
# Install uv and then install the requirements:
4141
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
42-
RUN export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system -r /requirements.txt --prerelease=allow --force-reinstall && \
42+
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
43+
uv pip compile --system --prerelease=allow \
44+
--verbose \
45+
--upgrade \
46+
--find-links=https://storage.googleapis.com/jax-releases/libtpu_releases.html \
47+
--find-links=https://storage.googleapis.com/libtpu-releases/index.html \
48+
--find-links=https://storage.googleapis.com/libtpu-wheels/index.html \
49+
--find-links=https://download.pytorch.org/whl/torch_stable.html \
50+
--emit-find-links \
51+
--no-emit-package pip \
52+
--no-emit-package setuptools \
53+
--output-file /requirements.txt \
54+
/requirements.in && \
55+
uv pip install --system --prerelease=allow --force-reinstall \
56+
-r /requirements.txt && \
57+
uv cache clean && \
4358
/tmp/clean-layer.sh
4459
ENV PATH="~/.local/bin:${PATH}"
4560

46-
# Try to force tensorflow to reliably install without breaking other installed deps
61+
# We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0.
62+
# Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23. However, that
63+
# libtpu causes pjrt api errors for torch 2.8.0. screenshot/5heUtdyaJ4MmR3D
64+
# https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111
65+
# https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26
4766
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
48-
uv pip freeze --system > /tmp/constraints.txt && \
49-
uv pip install --system -c /tmp/constraints.txt tensorflow-tpu -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force-reinstall && \
50-
rm /tmp/constraints.txt
67+
uv pip install --system --force-reinstall libtpu==0.0.17 && \
68+
uv cache clean && \
69+
/tmp/clean-layer.sh
5170

5271
# Kaggle Model Hub patches:
5372
ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py

tpu/requirements.txt renamed to tpu/requirements.in

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# TPU Utils
22
tpu-info
33
# Tensorflow packages
4-
tensorflow-tpu==${TENSORFLOW_VERSION}
5-
--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html
4+
# TODO: b/447621961 - re-enable tensorflow-tpu when a compatible libtpu can be found.
5+
tensorflow-cpu==${TENSORFLOW_VERSION}
66
tensorflow_hub
77
tensorflow-io
88
tensorflow-probability
@@ -13,8 +13,7 @@ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TOR
1313
torchaudio==${TORCHAUDIO_VERSION}
1414
torchvision==${TORCHVISION_VERSION}
1515
# Jax packages
16-
jax[tpu]>=0.5.2
17-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
16+
jax[tpu]
1817
distrax
1918
flax
2019
git+https://github.com/deepmind/dm-haiku

0 commit comments

Comments
 (0)