Orchestrate Multislice workloads using JobSet and Kueue

This tutorial demonstrates how to orchestrate multiple Multislice workloads on Google Kubernetes Engine (GKE) for improved resource utilization. You deploy an example JAX workload, run it on TPU Multislice, and implement Job queueing with JobSet and Kueue. Kueue determines when Jobs should run based on available resources, quotas, and a hierarchy for fair sharing among teams.

This tutorial is intended for Machine learning (ML) engineers and Platform admins and operators who are interested in the container orchestration capabilities of Kubernetes to train LLMs. To learn more about common roles and example tasks that we reference in Google Cloud content, see Common GKE user roles and tasks.

Before reading this page, ensure that you're familiar with the following:

Prepare the environment

  1. In the Google Cloud console, start a Cloud Shell instance:
    Open Cloud Shell

  2. Set the default environment variables by using the gcloud config set command:

    gcloud config set project PROJECT_ID 

    Replace PROJECT_ID with your Google Cloud project ID.

Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default. TPUs on Autopilot clusters are configured in the workload specification. For more information, see the Define your Multislice workloads with JobSets section.

Create a GKE cluster

In Cloud Shell, create a GKE cluster:

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=CONTROL_PLANE_LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

In this command:

  • The --location flag specifies the Compute Engine region of the control plane of the cluster.
  • The --cluster-version flag specifies the Kubernetes version for your cluster.
  • The --release-channel flag specifies the release channel for your cluster. In this case, the rapid channel supports the latest versions available in GKE.

Standard

gcloud container clusters create multislice-cluster \
    --location=CONTROL_PLANE_LOCATION

Replace CONTROL_PLANE_LOCATION with the location in which you want to create your cluster. Ensure that the location has capacity for the ct5lp-hightpu-4t machine type. Cluster creation might take several minutes.

If you use GKE Autopilot mode, skip to the Create the Kueue resources section. Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default.
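
Optionally, verify that the cluster was created and check its control plane version before you continue. This check is a minimal sketch; it assumes the cluster name and the CONTROL_PLANE_LOCATION value from the preceding commands:

# Confirm that the cluster exists and print its current control plane version.
gcloud container clusters describe multislice-cluster \
    --location=CONTROL_PLANE_LOCATION \
    --format="value(currentMasterVersion)"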

Create three Standard mode TPU slice node pools

In this section, you create TPU node pools by using the gcloud beta container node-pools create command.

  1. Create the first node pool named nodepool1:

    gcloud beta container node-pools create nodepool1 \
        --location=CONTROL_PLANE_LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID

    Replace NODE_LOCATION with one or more zones in the cluster region in which you want to create the nodes.

  2. Create the second node pool named nodepool2:

    gcloud beta container node-pools create nodepool2 \
        --location=CONTROL_PLANE_LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID
  3. Create the third node pool named nodepool3:

    gcloud beta container node-pools create nodepool3 \
        --location=CONTROL_PLANE_LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID

GKE creates three node pools. Each node pool is a separate TPU slice.

In the preceding steps, you used the gcloud beta container node-pools create command to create the node pools. These commands use the following flags:

  • --node-locations: the comma-separated list of one or more zones where GKE creates the node pools.
  • --machine-type: the type of machine to use for nodes. In this case, you used ct5lp-hightpu-4t. For more information about TPU-compatible machine types, see the table in Choose the TPU version.
  • --tpu-topology: the TPU topology to use for the node pool. In this case, you used 2x4. For more information about TPU topologies, see Choose the TPU topology.
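
Before you continue, you can optionally verify that the three TPU slice node pools are ready. This is a minimal check that assumes the cluster name and location used in the previous commands:

# List the TPU slice node pools and confirm that each one is RUNNING.
gcloud container node-pools list \
    --cluster=multislice-cluster \
    --location=CONTROL_PLANE_LOCATION \
    --format="table(name, status, config.machineType)"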

Create the Kueue resources

  1. Create the following kueue.yaml manifest:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
  2. Apply the kueue.yaml manifest:

    kubectl apply -f kueue.yaml 

Applying the manifest creates the following Kueue resources:

  • ResourceFlavor: An abstraction of the resources available in the cluster. In this example, the cluster has three TPU slices, each with a 2x4 topology (8 TPU chips per slice, 24 TPU chips in total).
  • ClusterQueue: A global queue that manages workloads and cluster resources.
  • LocalQueue: Groups closely related workloads that are typically run by a single tenant (user). Each LocalQueue points to a ClusterQueue from which resources are allocated to run its workloads. A Kueue Workload is an abstraction that represents a batch workload; in this case, each Workload is a JobSet.
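
Optionally, after you apply the manifest, verify that the Kueue resources exist and that the ClusterQueue reports the expected 24-chip quota. This is a minimal check that assumes Kueue is already installed in the cluster and uses the resource names from kueue.yaml:

# Confirm that the ResourceFlavor, ClusterQueue, and LocalQueue exist.
kubectl get resourceflavors,clusterqueues
kubectl get localqueues -n default

# Inspect the ClusterQueue quota, pending workloads, and admitted workloads.
kubectl describe clusterqueue cluster-queue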

Define your Multislice workloads with JobSets

In this section, you create three JobSets. A JobSet is a workload API that lets you manage a group of Kubernetes Jobs as a unit. The most common use case for a JobSet is distributed training, but you can also use it to run batch workloads.

The following JobSets run a JAX workload that outputs the global number of TPU chips in the Multislice, sleeps for 60 seconds to simulate model training time, and then exits.

  1. Install the JobSet API in your cluster:

    VERSION=v0.8.1
    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/$VERSION/manifests.yaml
  2. Create the following jobsets-multislice.yaml manifest:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 1
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  command:
                  - bash
                  - -c
                  - |
                    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                    python -c 'import jax; print("Global device count:", jax.device_count())'
                  resources:
                    limits:
                      google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 2
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  command:
                  - bash
                  - -c
                  - |
                    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                    python -c 'import jax; print("Global device count:", jax.device_count())'
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 3
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  command:
                  - bash
                  - -c
                  - |
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 1
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                hostNetwork: true
                dnsPolicy: ClusterFirstWithHostNet
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  securityContext:
                    privileged: true
                  command:
                  - bash
                  - -c
                  - |
                    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                    python -c 'import jax; print("Global device count:", jax.device_count())'
                  resources:
                    limits:
                      google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 2
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                hostNetwork: true
                dnsPolicy: ClusterFirstWithHostNet
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  securityContext:
                    privileged: true
                  command:
                  - bash
                  - -c
                  - |
                    pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                    python -c 'import jax; print("Global device count:", jax.device_count())'
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 3
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                hostNetwork: true
                dnsPolicy: ClusterFirstWithHostNet
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  securityContext:
                    privileged: true
                  command:
                  - bash
                  - -c
                  - |
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4
  3. Apply the jobsets-multislice.yaml manifest:

    kubectl apply -f jobsets-multislice.yaml 

GKE creates the Jobs with the following resource requests:

  • The multislice-1slice JobSet creates one Job that requires one TPU slice in total.
  • The multislice-2slice JobSet creates two Jobs that require two TPU slices in total.
  • The multislice-3slice JobSet creates three Jobs that require three TPU slices in total.

Because the cluster has only three TPU slices, not all of the JobSets can run at once. When Kueue enqueues all three JobSets, the Jobs of multislice-3slice run alone to completion. The multislice-1slice and multislice-2slice JobSets wait and then run together afterwards.
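
You can observe this queueing behavior as it happens. The following commands are an optional sketch; they assume the JobSet names from the jobsets-multislice.yaml manifest:

# Watch the JobSets: multislice-3slice runs to completion first, then
# multislice-1slice and multislice-2slice run together.
kubectl get jobsets -w

# Print how many slice replicas the multislice-3slice JobSet creates.
kubectl get jobset multislice-3slice \
    -o jsonpath='{.spec.replicatedJobs[0].replicas}'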

Verify Kueue admitted the workloads

  1. Check the enqueued workloads in Kueue:

    kubectl get workloads 

    The output is similar to the following:

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s

Kueue admits one or more workloads, depending on the TPU resources that they require. In the preceding output, only the multislice-3slice workload has been admitted by cluster-queue; the other two workloads are still waiting for available TPU resources.
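
If a workload waits in the queue longer than you expect, you can inspect its admission conditions. The workload name below is taken from the example output; the random suffix in your cluster will differ:

# Watch admission status change as TPU slices become available.
kubectl get workloads -w

# Show the admission conditions for a specific workload. Replace the
# suffix with the one from your own output.
kubectl describe workload jobset-multislice-1slice-2530a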

Monitor the workloads

JobSet and node pool observability metrics and dashboards in the Google Cloud console are generally available.

Dashboards

To view the status of your TPU multi-host node pools on GKE, go to the GKE TPU Node Pool Status dashboard provided by Cloud Monitoring:

Go to GKE TPU Node Pool Status

For more information, see Monitor health metrics for TPU nodes and node pools.

In the Kubernetes Engine AI/ML page in the Google Cloud console, the AI deployment > Jobs tab displays a JobSet monitoring dashboard with comprehensive information about the health and performance of JobSets and their underlying infrastructure, such as JobSet status, replica readiness, and replica state. The dashboard also includes infrastructure metrics, such as CPU, GPU, TPU, memory, and storage metrics. For more information, see Monitor JobSet health with metrics.

Monitor which Pods are running

To check which Pods GKE scheduled for your JobSets, run the following command:

kubectl get pods

The output is similar to the following:

NAME                                READY   STATUS      RESTARTS   AGE
multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s

Notice that GKE scheduled, created, and ran the Pods for the multislice-3slice JobSet first. Then, GKE ran the Pods from the multislice-1slice and multislice-2slice JobSets.
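
Optionally, to confirm that the JAX workload saw the expected number of TPU chips, read the logs of one of the Pods. The Pod name below comes from the example output and will differ in your cluster; a single 2x4 slice exposes 8 chips, so the single-slice JobSet should report a global device count of 8:

# Print the logs of a Pod from the multislice-1slice JobSet. The output
# should include a line such as "Global device count: 8".
kubectl logs multislice-1slice-slice-0-0-pf2ll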

Monitor JobSet health with metrics

To understand whether a JobSet is running as expected, or to infer whether it was interrupted, you can use Prometheus metrics from the JobSet metrics package, such as kube_jobset_succeeded_replicas.

JobSet health metrics are supported only in GKE version 1.32.1-gke.135700 or later, and are enabled by default in newly created clusters that run supported versions. For existing clusters that you upgrade to a supported version, you must manually enable the JobSet metrics package. For more information, refer to the documentation.
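
To check which GKE version your cluster runs and which Cloud Monitoring components are enabled, you can describe the cluster. This is a sketch only; the exact name of the JobSet metrics component in the output can vary with GKE and gcloud versions:

# Check the cluster version and the enabled monitoring components.
gcloud container clusters describe multislice-cluster \
    --location=CONTROL_PLANE_LOCATION \
    --format="value(currentMasterVersion, monitoringConfig.componentConfig.enableComponents)"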

For this tutorial, check the JobSet completion with this PromQL query:

kube_jobset_succeeded_replicas{
  cluster="multislice-cluster",
  jobset_name=~"multislice-.*"}

Monitor JobSet uptime, times to recover (TTR), and times between interruptions (TBI)

The following metrics are useful for monitoring the availability of a JobSet:

  • kubernetes.io/jobset/uptime: Total time that the JobSet has been available.
  • kubernetes.io/jobset/times_to_recover: Distribution of recovery period for a JobSet. Each sample indicates a single recovery event from a downtime period for the JobSet.
  • kubernetes.io/jobset/times_between_interruptions: Distribution of the interval between the end of the previous interruption and the beginning of the current interruption for a JobSet. Each sample indicates a single duration between the previous and the current interruption.

These metrics apply to JobSets that have exactly one GPU or TPU replicated job. The calculation of the metrics is based only on the availability of that single replicated job. The metrics are supported in all GKE versions.

To view uptime for the JobSets that you used in this tutorial, run the following PromQL query:

avg_over_time(
  kubernetes_io:jobset_uptime{
    monitored_resource="k8s_entity", entity_type="jobset",
    entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}])

To view the TBI distributions for the JobSets from this tutorial, run the following PromQL query:

histogram_quantile(0.50,
  sum_over_time(
    kubernetes_io:jobset_times_between_interruptions_bucket{
      monitored_resource="k8s_entity", entity_type="jobset",
      entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))

You can extend the interval of the query to a longer time horizon, such as 7 days, and compute the mean time between interruptions (MTBI) over this period:

sum(sum_over_time(
  kubernetes_io:jobset_times_between_interruptions_sum{
    monitored_resource="k8s_entity", entity_type="jobset",
    entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))
/
sum(sum_over_time(
  kubernetes_io:jobset_times_between_interruptions_count{
    monitored_resource="k8s_entity", entity_type="jobset",
    entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))

To view the TTR distributions, you can run the following PromQL queries:

histogram_quantile(0.50,
  sum_over_time(
    kubernetes_io:jobset_times_to_recover_bucket{
      monitored_resource="k8s_entity", entity_type="jobset",
      entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))

After you increase the query interval to a longer time horizon, such as 7 days, you can compute the mean time to recover (MTTR) over this period:

sum(sum_over_time(
  kubernetes_io:jobset_times_to_recover_sum{
    monitored_resource="k8s_entity", entity_type="jobset",
    entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))
/
sum(sum_over_time(
  kubernetes_io:jobset_times_to_recover_count{
    monitored_resource="k8s_entity", entity_type="jobset",
    entity_name=~"multislice-.*", cluster_name="multislice-cluster"}[${__interval}]))

Enable Kueue workload priorities and preemption

Optionally, you can assign priorities to Kueue workloads, which determine the order in which Kueue admits enqueued workloads.

  1. Update your ClusterQueue to have a preemption policy:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
      preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
  2. Create a PriorityClass for each distinct priority level you want to assign to workloads:

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
  3. Assign the priorityClassName to your JobSet:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 1
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                priorityClassName: low-priority
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  command:
                  - bash
                  - -c
                  - |
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4 # Number of TPU chips per worker

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
      - name: slice
        replicas: 1
        template:
          spec:
            parallelism: 2
            completions: 2
            backoffLimit: 0
            template:
              spec:
                hostNetwork: true
                dnsPolicy: ClusterFirstWithHostNet
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                  cloud.google.com/gke-tpu-topology: 2x4
                priorityClassName: low-priority
                containers:
                - name: jax-tpu
                  image: python:3.8
                  ports:
                  - containerPort: 8471
                  - containerPort: 8080
                  securityContext:
                    privileged: true
                  command:
                  - bash
                  - -c
                  - |
                    sleep 60
                  resources:
                    limits:
                      google.com/tpu: 4 # Number of TPU chips per worker

The updated ClusterQueue includes a preemption policy, which defines how Kueue shares the available resources. The policy specifies that a workload can be preempted if a higher-priority workload needs the resources. Workloads with a lower priority value are more likely to be preempted by higher-priority workloads.
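
To see the preemption policy in action, apply the updated manifests and watch which workloads Kueue admits and preempts. The PriorityClass and JobSet file names below are placeholders for wherever you saved the manifests shown above:

# Re-apply the updated Kueue configuration, then create the PriorityClass
# and the low-priority JobSet (the last two file names are placeholders).
kubectl apply -f kueue.yaml
kubectl apply -f priorityclass.yaml
kubectl apply -f low-priority-jobset.yaml

# Observe admission and preemption decisions; preempted workloads lose
# their admission and return to the queue.
kubectl get workloads -w
kubectl describe clusterqueue cluster-queue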