🐛 Bug
When using torch_xla version 2.6.0 on TPU v5e, torch.nn.functional.scaled_dot_product_attention is VERY slow for some reason (comparison with previous versions below).
To Reproduce
I used this simple code:
```python
import math
from functools import partial, wraps
import os
import torch
import timeit
import statistics
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


def wrap_with_mark_step(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        res = func(*args, **kwargs)
        xm.mark_step()
        return res
    return wrapper


@wrap_with_mark_step
def sdpa(query, key, value):
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)


@wrap_with_mark_step
def standard_attention(query, key, value):
    # from sdpa's doc implementation (it's pretty straightforward)
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value


def time_attention(attention_fn, query, key, value):
    # Warm-up
    for _ in range(100):
        attention_fn(query, key, value)
    # Run timings with repeats
    repeats = 10
    number = 100
    times = timeit.repeat(partial(attention_fn, query, key, value), repeat=repeats, number=number)
    mean_time = statistics.mean(times) / number * 1000  # ms per call
    std_dev = statistics.stdev(times) / number * 1000  # ms per call
    print(f" {attention_fn.__name__} average time: {mean_time:.3f} ms per call, dev: {std_dev:.3f} ms per call")


if __name__ == '__main__':
    os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
    batch_size = 32
    seq_len = 256
    num_heads = 16
    head_dim = 64
    device = xm.xla_device()

    # Generate input tensors with correct shape: (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

    assert torch.allclose(sdpa(query, key, value), standard_attention(query, key, value), rtol=1e-02, atol=1e-02)
    print("Both attention functions produce about the same output, test time performance now")

    print("No mixed precision:")
    time_attention(sdpa, query, key, value)
    time_attention(standard_attention, query, key, value)

    print("With mixed precision:")
    with autocast(xm.xla_device(), enabled=True, dtype=torch.bfloat16):
        time_attention(sdpa, query, key, value)
        time_attention(standard_attention, query, key, value)

    print(f"torch version: {torch.__version__}, torch_xla version: {torch_xla.__version__}")
```

Steps to reproduce the behavior:
- I tested it on GKE (with my code in a ConfigMap); this is the YAML I used:
```yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: test-sdpa-260
spec:
  backoffLimit: 0
  completionMode: Indexed
  completions: 2
  parallelism: 2
  template:
    spec:
      containers:
      - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_libtpu_3.10_tpuvm
        name: test
        command:
        - bash
        args:
        - -c
        - python -m my_code.test_sdpa
        env:
        - name: PYTHONUNBUFFERED
          value: "1"
        - name: PJRT_DEVICE
          value: "TPU"
        ports:
        - containerPort: 12355
        - containerPort: 8080
        - containerPort: 8431
        - containerPort: 8471
        - containerPort: 8476
        - containerPort: 8477
        - containerPort: 8478
        - containerPort: 8479
        resources:
          limits:
            google.com/tpu: '4'
          requests:
            google.com/tpu: '4'
        securityContext:
          privileged: true
        volumeMounts:
        - name: configmap-volume
          mountPath: /my_code
        - mountPath: /dev/shm
          name: shm
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
      restartPolicy: Never
      subdomain: headless-svc
      tolerations:
      - effect: NoSchedule
        key: google.com/tpu
        operator: Exists
      volumes:
      - emptyDir:
          medium: Memory
        name: shm
      - name: configmap-volume
        configMap:
          name: my-configmap
```
- Then I tested the same script with the other official Docker images (us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_libtpu_3.10_tpuvm and us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_libtpu_3.10_tpuvm).
- The outputs I got are:
```
Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 0.714 ms per call, dev: 0.000 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.001 ms per call
With mixed precision:
 sdpa average time: 0.251 ms per call, dev: 0.000 ms per call
 standard_attention average time: 0.251 ms per call, dev: 0.000 ms per call
torch version: 2.4.0+libtpu, torch_xla version: 2.4.0+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 0.722 ms per call, dev: 0.001 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
 sdpa average time: 0.253 ms per call, dev: 0.003 ms per call
 standard_attention average time: 0.239 ms per call, dev: 0.009 ms per call
torch version: 2.5.1+libtpu, torch_xla version: 2.5.1+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 116.679 ms per call, dev: 0.447 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
 sdpa average time: 106.377 ms per call, dev: 0.511 ms per call
 standard_attention average time: 0.347 ms per call, dev: 0.003 ms per call
torch version: 2.6.0+libtpu, torch_xla version: 2.6.0+libtpu
```

So sdpa is more than 100x slower in version 2.6.0 than in previous versions, and than the straightforward implementation (which itself also regressed under mixed precision)!
This is very concerning, considering that attention is the basic operator behind most modern architectures.
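To help narrow down where the time goes on the 2.6.0 image, this is a small diagnostic I would run (my own sketch, not part of the timings above; it uses the helpers from the torch_xla troubleshooting guide, `torch_xla.debug.metrics` and the internal `torch_xla._XLAC._get_xla_tensors_text`). Any `aten::*` counters in the metrics report would indicate a CPU fallback, and the dumped IR shows what `scaled_dot_product_attention` actually lowers to:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
q = torch.randn(32, 16, 256, 64, device=device)
k = torch.randn(32, 16, 256, 64, device=device)
v = torch.randn(32, 16, 256, 64, device=device)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# Pending IR for the output tensor, before it is materialized by mark_step().
print(torch_xla._XLAC._get_xla_tensors_text([out]))
xm.mark_step()

# aten::* counters in this report mean an op ran on CPU via the aten fallback.
print(met.metrics_report())
```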
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v5e
- torch_xla version: 2.4.0, 2.5.1, 2.6.0
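Until this is fixed, the numbers above suggest the manual implementation is a usable stopgap. A hedged workaround sketch (my own, covering only the argument combination used in the repro script, not the full SDPA API):

```python
import math
import torch

def manual_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None):
    # Only the path exercised in the repro script above is supported here.
    assert attn_mask is None and dropout_p == 0.0 and not is_causal
    scale_factor = scale if scale is not None else 1 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value

# In model code on the affected version, call manual_sdpa(q, k, v) where
# torch.nn.functional.scaled_dot_product_attention(q, k, v) was used before.
```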