Run TPU workloads in a Docker container
Docker containers make configuring applications easier by combining your code and all needed dependencies in one distributable package. You can run Docker containers within TPU VMs to simplify configuring and sharing your Cloud TPU applications. This document describes how to set up a Docker container for each ML framework supported by Cloud TPU.
Train a PyTorch model in a Docker container
TPU device
- Create a Cloud TPU VM:
  gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
- Connect to the TPU VM using SSH:
  gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
- Make sure your Google Cloud user has been granted the Artifact Registry Reader role. For more information, see Granting Artifact Registry roles. 
- Start a container in the TPU VM using the PyTorch/XLA 2.6 release image:
  sudo docker run --net=host -ti --rm --name your-container-name --privileged \
    us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
- Configure the TPU runtime. There are two PyTorch/XLA runtime options: PJRT and XRT. We recommend that you use PJRT unless you have a reason to use XRT. To learn more about the different runtime configurations, see the PJRT runtime documentation. An optional check that the runtime can see the TPU follows this list.
  - PJRT:
    export PJRT_DEVICE=TPU
  - XRT:
    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
- Clone the PyTorch/XLA repository:
  git clone --recursive https://github.com/pytorch/xla.git
- Train ResNet50:
  python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
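Before starting a longer training run, you can optionally confirm that the runtime can see the TPU from inside the container. This check is not part of the official procedure; it is a minimal sketch that assumes you exported PJRT_DEVICE=TPU and uses the torch_xla package that ships in the image:

  python3 -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"

If the runtime is configured correctly, this prints an XLA device such as xla:0.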
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container
- Type exit to exit from the TPU VM
- Delete the TPU VM:
  gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU slice
When you run PyTorch code on a TPU slice, you must run your code on all TPU workers at the same time. One way to do this is to use the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command flags. The following procedure shows you how to use a Docker container to set up each TPU worker.
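For example, after you create the TPU VM in the next step, a command like the following runs hostname on every worker, which is a quick way to confirm that the --worker=all pattern reaches all hosts. This check is illustrative and not part of the official procedure:

  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="hostname"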
- Create a TPU VM:
  gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
- Add the current user to the Docker group:
  gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
- Clone the PyTorch/XLA repository:
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
- Run the training script in a container on all TPU workers:
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"
  Docker command flags:
  - --rm removes the container after its process terminates.
  - --privileged exposes the TPU device to the container.
  - --net=host binds all of the container's ports to the TPU VM to allow communication between the hosts in the pod.
  - -v mounts the cloned xla repository from the host into the container.
  - -e sets environment variables.
When the training script completes, make sure you clean up the resources.
Delete the TPU VM using the following command:
gcloud compute tpus tpu-vm delete your-tpu-name \
  --zone=us-central2-b
Train a JAX model in a Docker container
TPU device
- Create the TPU VM:
  gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
- Connect to the TPU VM using SSH:
  gcloud compute tpus tpu-vm ssh your-tpu-name --zone=europe-west4-a
- Start the Docker daemon in the TPU VM:
  sudo systemctl start docker
- Start the Docker container:
  sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
- Install JAX. A quick check that JAX can see the TPU follows this list.
  pip install jax[tpu]
- Install FLAX:
  pip install --upgrade clu
  git clone https://github.com/google/flax.git
  pip install --user -e flax
- Install the tensorflow and tensorflow-datasets packages:
  pip install tensorflow
  pip install tensorflow-datasets
- Run the FLAX MNIST training script:
  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
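As a quick, optional check (not part of the official procedure) that JAX installed correctly and can see the TPU, you can print the device count from inside the container. On a v2-8 this should report 8 devices:

  python3 -c "import jax; print(jax.device_count())"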
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container
- Type exit to exit from the TPU VM
- Delete the TPU VM:
  gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU slice
When you run JAX code on a TPU slice, you must run it on all TPU workers at the same time. One way to do this is to use the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command flags. The following procedure shows you how to create a Docker image to make setting up each TPU worker easier.
- Create a file named Dockerfile in your current directory and paste the following text:
  FROM python:3.10
  RUN pip install jax[tpu]
  RUN pip install --upgrade clu
  RUN git clone https://github.com/google/flax.git
  RUN pip install --user -e flax
  RUN pip install tensorflow
  RUN pip install tensorflow-datasets
  WORKDIR ./flax/examples/mnist
- Prepare an Artifact Registry:
  gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 \
    --description="Docker repository" \
    --project=your-project
  gcloud artifacts repositories list \
    --project=your-project
  gcloud auth configure-docker europe-west4-docker.pkg.dev
- Build the Docker image:
  docker build -t your-image-name .
- Add a tag to your Docker image before pushing it to the Artifact Registry. For more information on working with Artifact Registry, see Work with container images.
  docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
- Push your Docker image to the Artifact Registry:
  docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
- Create a TPU VM:
  gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
- Pull the Docker image from the Artifact Registry on all TPU workers:
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
- Run the container on all TPU workers. An optional check that the containers can see the TPU follows this list.
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
- Run the training script on all TPU workers:
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.learning_rate=0.05 --config.num_epochs=5"
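After starting the containers, you can optionally confirm that each worker's container can see the TPU devices before launching training. This check is not part of the official procedure; it is a sketch that assumes the container is running under the name your-container-name, as in the previous steps:

  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 -c 'import jax; print(jax.device_count())'"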
When the training script completes, make sure you clean up the resources.
- Shut down the container on all workers:
  gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
- Delete the TPU VM:
  gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a