Run a batch workload with Pathways

In this document, batch workloads are JAX workloads that run to completion and are deployed in the same GKE cluster as the Pathways components, specifically alongside the Pathways controller components (the IFRT proxy server and the Pathways resource manager). When the JAX workload completes, the Pathways cluster components are terminated. This guide demonstrates the flow with a JAX training workload.

Before you begin

Make sure you have:

Build a training image using Maxtext

MaxText is an open-source, large language model (LLM) project developed by Google. It's written in JAX and designed to be highly performant and scalable, running efficiently on Google Cloud TPUs and GPUs.

To build a MaxText Docker image using the latest stable JAX release from the OSS GitHub repository, run the following commands:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
# This script needs bash version >= 4.2 to execute.
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner

These commands push the MaxText Docker image to gcr.io/$PROJECT/${USER}_runner. You can use this Docker image to run training on TPUs using the Pathways backend.

Run a batch workload using the PathwaysJob API

The following manifest deploys the Pathways components and runs a MaxText workload using the PathwaysJob API. The workload is encapsulated in the main container and exercises train.py.

Copy the following YAML into a file named pathways-job-batch-training.yaml and update the editable values.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
              base_output_directory=gs://BUCKET_NAME \
              run_name=RUN_NAME \
              per_device_batch_size=1 \
              enable_checkpointing=false \
              remat_policy=full \
              global_parameter_scale=1 \
              steps=20 \
              max_target_length=2048 \
              use_iota_embed=true \
              reuse_example_batch=1 \
              dataset_type=synthetic \
              attention=flash \
              gcs_metrics=True \
              enable_single_controller=True

Replace the following:

  • USER: your Google Cloud user ID
  • MAX_RESTARTS: the maximum number of times the job can be restarted
  • TPU_MACHINE_TYPE: the TPU machine type
  • TOPOLOGY: the TPU v4 or later topology. For more information about TPU versions and supported topologies, see TPU versions
  • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
  • BUCKET_NAME: a Cloud Storage bucket for storing temporary files
  • PROJECT: your Google Cloud project ID
  • RUN_NAME: a user-assigned name to identify the workflow run

You can deploy the PathwaysJob YAML as follows:

kubectl apply -f pathways-job-batch-training.yaml 

To view the PathwaysJob instance created by the previous command, use:

kubectl get pathwaysjob 

The output should look like this:

NAME             AGE
pathways-trial   9s

To modify an attribute of the PathwaysJob instance, delete the instance, edit the YAML, and apply the file again to create a new PathwaysJob instance.

You can follow the progress of your workload in the Logs Explorer. Filter for your JAX container by choosing main under the Container Name filter.

You should see logs like the following, which indicate that training is progressing. The workload completes after the configured 20 steps.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

To delete the PathwaysJob instance, you can use the following command:

kubectl delete -f pathways-job-batch-training.yaml 

Run a batch workload using XPK

Now you can submit the prebuilt MaxText Docker image using XPK with the same command you used previously.

xpk workload create-pathways \
  --workload=WORKLOAD \
  --cluster=CLUSTER \
  --num-slices=WORKLOAD_NODEPOOL_COUNT \
  --tpu-type=TPU_TYPE \
  --project=PROJECT \
  --zone=ZONE \
  --docker-image='gcr.io/PROJECT/USER_runner' \
  --command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Replace the following:

  • WORKLOAD: a unique name to identify your workload
  • CLUSTER: the name of your GKE cluster
  • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
  • TPU_TYPE: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions
  • PROJECT: your Google Cloud project ID
  • ZONE: the zone where you plan to run your workload
  • USER: your Google Cloud user ID
  • RUN_NAME: a user-assigned name to identify the workflow run

You should see output like the following:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Use the link in the output from the previous XPK command to follow the progress of your workload. You can filter the logs for your JAX container by choosing jax-tpu under the Container Name filter.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

The workload completes after the specified number of steps. If you want to terminate it prematurely, use the following command:

xpk workload delete \
  --workload=WORKLOAD \
  --cluster=CLUSTER \
  --project=PROJECT \
  --zone=ZONE

What's next