Train a TensorFlow model with Keras on Google Kubernetes Engine

This section shows how to fine-tune a BERT model (bert-base-cased) for sequence classification on the GLUE CoLA dataset by using the Hugging Face transformers library with TensorFlow. The Job downloads the dataset into a mounted Parallelstore-backed volume, so model training reads the data directly from that volume.

Prerequisites

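This example requires a PersistentVolumeClaim named parallelstore-pvc, backed by a Parallelstore instance, in the same namespace where you run the Job.

Save the following YAML manifest as parallelstore-csi-job-example.yaml for your model training Job.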

apiVersion: batch/v1
kind: Job
metadata:
  name: parallelstore-csi-job-example
spec:
  template:
    metadata:
      annotations:
        gke-parallelstore/cpu-limit: "0"
        gke-parallelstore/memory-limit: "0"
    spec:
      securityContext:
        runAsUser: 1000
        runAsGroup: 100
        fsGroup: 100
      containers:
      - name: tensorflow
        image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d
        command: ["bash", "-c"]
        args:
        - |
          pip install transformers datasets
          python - <<EOF
          from datasets import load_dataset
          dataset = load_dataset("glue", "cola", cache_dir='/data')
          dataset = dataset["train"]
          from transformers import AutoTokenizer
          import numpy as np
          tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
          tokenized_data = tokenizer(dataset["sentence"], return_tensors="np", padding=True)
          tokenized_data = dict(tokenized_data)
          labels = np.array(dataset["label"])
          from transformers import TFAutoModelForSequenceClassification
          from tensorflow.keras.optimizers import Adam
          model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
          model.compile(optimizer=Adam(3e-5))
          model.fit(tokenized_data, labels)
          EOF
        volumeMounts:
        - name: parallelstore-volume
          mountPath: /data
      volumes:
      - name: parallelstore-volume
        persistentVolumeClaim:
          claimName: parallelstore-pvc
      restartPolicy: Never
  backoffLimit: 1
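The training script tokenizes the entire CoLA train split in a single call with return_tensors="np". NumPy arrays must be rectangular, so padding=True pads every sentence to the length of the longest one in that call, and the tokenizer returns an attention_mask that marks which positions are real tokens. The following pure-Python sketch illustrates that padding behavior; it is not the transformers implementation. The helper pad_batch and the token IDs are hypothetical, except that 101 and 102 stand for [CLS] and [SEP] and 0 for [PAD] in the bert-base-cased vocabulary.

```python
# Hypothetical token IDs: 101 = [CLS], 102 = [SEP]; the others are illustrative.
PAD_ID = 0  # [PAD] token ID in the bert-base-cased vocabulary


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad every sequence to the longest one and build an attention mask.

    Hypothetical helper that mimics what padding=True produces; the real
    tokenizer also returns token_type_ids for BERT, omitted here.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids = []
    attention_mask = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        # 1 marks a real token the model attends to; 0 marks padding.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 1188, 102], [101, 1188, 1110, 170, 102]])
print(batch["input_ids"])       # [[101, 1188, 102, 0, 0], [101, 1188, 1110, 170, 102]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Because the whole split is padded to its single longest sentence, every row costs as much compute as that sentence. Tokenizing in smaller batches keeps padding shorter, at the cost of variable-shaped inputs.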

Apply the YAML manifest to the cluster.

kubectl apply -f parallelstore-csi-job-example.yaml

Check your data loading and model training progress with the following command:

POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')
kubectl logs -f $POD_NAME -c tensorflow
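Streaming the logs shows progress, but the Job's own status records whether training finished. Kubernetes reports a finished Job through a status condition of type Complete or Failed whose status is "True". The following sketch reads that from the output of kubectl get job parallelstore-csi-job-example -o json; the function name job_outcome and the "NotFinished" return value are hypothetical choices for illustration.

```python
import json


def job_outcome(job_json: str) -> str:
    """Return "Complete", "Failed", or "NotFinished" for a Job.

    job_json is assumed to be the text printed by
    `kubectl get job <name> -o json`.
    """
    job = json.loads(job_json)
    for condition in job.get("status", {}).get("conditions", []):
        # A finished Job carries a Complete or Failed condition set to "True".
        if condition.get("type") in ("Complete", "Failed") and condition.get("status") == "True":
            return condition["type"]
    # No terminal condition yet: the Job is still pending, running, or suspended.
    return "NotFinished"


# Hypothetical, trimmed-down status of a Job whose training pod succeeded.
sample = '{"status": {"succeeded": 1, "conditions": [{"type": "Complete", "status": "True"}]}}'
print(job_outcome(sample))  # Complete
```

From a shell, kubectl wait --for=condition=complete job/parallelstore-csi-job-example blocks until the same Complete condition appears.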