
Commit 4583051

Adapt Splash Attention from TorchPrime (#8911)
1 parent aad87e6 commit 4583051

File tree

4 files changed: +709 −2 lines changed

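For orientation, the sketch below distills the call pattern that the new test exercises: build an SPMD mesh, wrap the mesh and partition specs in a SplashAttentionConfig, and pass the config serialized to JSON into splash_attention. It mirrors the test's `__main__` setup and default shapes, only runs in a TPU SPMD environment, and is an illustration rather than part of the commit; the (num_devices // 2, 2) mesh shape is the one the test uses.

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
from torch_xla.distributed.spmd import Mesh
from torch_xla.experimental.splash_attention import (
    SplashAttentionConfig,
    splash_attention,
)

# SPMD mesh over all devices with the same ("data", "fsdp") axes as the test.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (num_devices // 2, 2), ("data", "fsdp"))
xs.set_global_mesh(mesh)

# The config carries the mesh and partition specs into the kernel call as JSON.
config = SplashAttentionConfig(
    mesh=str(xs.get_global_mesh()),
    qkv_partition_spec=(("data", "fsdp"), None, None, None),
    segment_ids_partition_spec=(("data", "fsdp"), None),
)

# NUM_HEADS, SEQ_LEN, and HEAD_DIM must all be >= 128; k/v may use fewer heads (GQA).
q = torch.randn(4, 128, 128, 128, requires_grad=True).to("xla")
k = torch.randn(4, 64, 128, 128, requires_grad=True).to("xla")
v = torch.randn(4, 64, 128, 128, requires_grad=True).to("xla")

o = splash_attention(q, k, v, config.to_json())
torch.sum(o).backward()
torch_xla.sync()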

test/test_splash_attention.py

Lines changed: 329 additions & 0 deletions
@@ -0,0 +1,329 @@
import logging
import sys
import unittest

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
from torch_xla._internal import tpu
from torch_xla.distributed.spmd import Mesh
from torch_xla.experimental.custom_kernel import flash_attention

from torch_xla.experimental.splash_attention import (
    SplashAttentionConfig,
    splash_attention,
)

import torch_xla.core.xla_builder as xb

if xr.device_type() == "TPU":
  from torch_xla.experimental.custom_kernel import jax_import_guard

  jax_import_guard()
  import jax


def with_jax_high_precision(func):

  def wrapper(*args, **kwargs):
    jax.config.update("jax_default_matmul_precision", "highest")
    try:
      result = func(*args, **kwargs)
    finally:
      jax.config.update("jax_default_matmul_precision", "default")
    return result

  return wrapper


class SplashAttentionTest(unittest.TestCase):

  @with_jax_high_precision
  def setUp(self):
    # Common dimensions for all tests. The splash attention kernel requires
    # NUM_HEADS, SEQ_LEN, and HEAD_DIM to be >= 128.
    self.BATCH_SIZE = 4
    # Test GQA with different Q and KV heads.
    self.NUM_Q_HEADS = 128
    self.NUM_KV_HEADS = 64
    self.NUM_HEADS = 128
    self.SEQ_LEN = 128
    self.HEAD_DIM = 128
    self.partition_spec = (("data", "fsdp"), None, None, None)
    segment_ids_partition_spec = (("data", "fsdp"), None)
    self.config = SplashAttentionConfig(
        mesh=str(xs.get_global_mesh()),
        qkv_partition_spec=self.partition_spec,
        segment_ids_partition_spec=segment_ids_partition_spec,
    )
    self.q, self.k, self.v, self.q_sa, self.k_sa, self.v_sa = self.ab_comparsion_input_generation(
    )
    segment_ids = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN).to("xla")
    for i in range(self.BATCH_SIZE):
      segment_ids[i, :] = i
    self.segment_ids_sa = segment_ids.clone().detach()
    self.o = flash_attention(
        self.q,
        self.k,
        self.v,
        True,
        segment_ids,
        segment_ids,
        partition_spec=self.partition_spec,
        mesh=xs.get_global_mesh(),
    )
    torch_xla.sync()
    loss = torch.sum(self.o)
    loss.backward()
    torch_xla.sync()
    self.q_grad, k_grad, v_grad = self.q.grad, self.k.grad, self.v.grad
    with torch.no_grad():
      self.k_grad = self.maybe_reduce_kv_grad(k_grad)
      self.v_grad = self.maybe_reduce_kv_grad(v_grad)

  def maybe_repeat_kv(self, hidden_state):
    if hidden_state.size(1) == self.NUM_Q_HEADS:
      return hidden_state
    num_kv_group = self.NUM_Q_HEADS // self.NUM_KV_HEADS
    return hidden_state.repeat_interleave(num_kv_group, dim=1)

  def maybe_reduce_kv_grad(self, hidden_state_grad):
    # For GQA, the kv grad shape is [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN,
    # HEAD_DIM]. We need to convert it back to [BATCH_SIZE, NUM_KV_HEADS,
    # SEQ_LEN, HEAD_DIM]. The returned grad is summed over the kv heads within
    # each group to preserve the magnitude of the gradients.
    if hidden_state_grad.size(1) == self.NUM_KV_HEADS:
      return hidden_state_grad
    num_kv_group = self.NUM_Q_HEADS // self.NUM_KV_HEADS
    return hidden_state_grad.view(
        self.BATCH_SIZE,
        self.NUM_KV_HEADS,
        num_kv_group,
        self.SEQ_LEN,
        self.HEAD_DIM,
    ).sum(dim=2)

  def ab_comparsion_input_generation(self):
    q = torch.randn(
        self.BATCH_SIZE,
        self.NUM_Q_HEADS,
        self.SEQ_LEN,
        self.HEAD_DIM,
        requires_grad=True).to("xla")
    k = torch.randn(
        self.BATCH_SIZE,
        self.NUM_KV_HEADS,
        self.SEQ_LEN,
        self.HEAD_DIM,
        requires_grad=True,
    ).to("xla")
    v = torch.randn(
        self.BATCH_SIZE,
        self.NUM_KV_HEADS,
        self.SEQ_LEN,
        self.HEAD_DIM,
        requires_grad=True,
    ).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()
    q_sa = q.clone().detach().requires_grad_(True)
    k_sa = k.clone().detach().requires_grad_(True)
    v_sa = v.clone().detach().requires_grad_(True)
    # Repeat the kv tensors to match the q tensor heads. This is required for
    # flash attention.
    k = self.maybe_repeat_kv(k)
    k.retain_grad()
    v = self.maybe_repeat_kv(v)
    v.retain_grad()
    torch_xla.sync()
    return q, k, v, q_sa, k_sa, v_sa

  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
    k = self.maybe_repeat_kv(k)
    v = self.maybe_repeat_kv(v)
    attn_weight = q @ k.transpose(-2, -1)
    if attn_mask is not None:
      # Mask out the irrelevant parts.
      attn_weight = attn_weight.masked_fill(attn_mask,
                                            torch.finfo(attn_weight.dtype).min)
    if ab is not None:
      attn_weight = attn_weight + ab
    attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1)
    attn_output = attn_weight @ v
    return attn_output

  @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_splash_attention_base(self):
    q, k, v, q_sa, k_sa, v_sa = self.ab_comparsion_input_generation()
    attention_mask = torch.triu(
        torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to("xla")

    o = self._attention(q, k, v, attn_mask=attention_mask)
    torch_xla.sync()
    loss = torch.sum(o)
    loss.backward()
    q_grad, k_grad, v_grad = q.grad, k.grad, v.grad
    torch_xla.sync()

    o_sa = splash_attention(q_sa, k_sa, v_sa, self.config.to_json())
    torch_xla.sync()
    loss_sa = torch.sum(o_sa)
    loss_sa.backward()
    q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
    torch_xla.sync()

    with torch.no_grad():
      k_grad = self.maybe_reduce_kv_grad(k_grad)
      v_grad = self.maybe_reduce_kv_grad(v_grad)

    torch.testing.assert_close(o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5)

    for org_grad, sa_grad in zip([q_grad, k_grad, v_grad],
                                 [q_grad_sa, k_grad_sa, v_grad_sa],
                                 strict=False):
      torch.testing.assert_close(
          org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2)

  @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_splash_attention_sharding(self):
    n_devices = xr.global_runtime_device_count()
    q = (
        torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN,
                    self.HEAD_DIM).requires_grad_(True).to("xla"))
    k = (
        torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN,
                    self.HEAD_DIM).requires_grad_(True).to("xla"))
    v = (
        torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN,
                    self.HEAD_DIM).requires_grad_(True).to("xla"))
    o = splash_attention(q, k, v, self.config.to_json())
    torch_xla.sync()
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(o),
        f"{{devices=[{n_devices},1,1,1]<=[{n_devices}]}}",
    )

  @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_splash_attention_segment_id(self):
    # Test the segment ids in splash attention against the flash attention
    # kernel.
    q_sa = self.q_sa.clone().detach().requires_grad_(True)
    k_sa = self.k_sa.clone().detach().requires_grad_(True)
    v_sa = self.v_sa.clone().detach().requires_grad_(True)
    for i in [q_sa, k_sa, v_sa]:
      i.retain_grad()
    segment_ids_sa = self.segment_ids_sa.clone().detach()
    o_sa = splash_attention(
        q_sa,
        k_sa,
        v_sa,
        self.config.to_json(),
        decoder_segment_ids=segment_ids_sa.to("xla"))
    loss_sa = torch.sum(o_sa)
    loss_sa.backward()
    q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
    torch_xla.sync()
    torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5)
    for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad],
                                 [q_grad_sa, k_grad_sa, v_grad_sa],
                                 strict=False):
      torch.testing.assert_close(
          org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2)

  @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision
  def test_splash_attention_aot_traceable(self):
    from functorch.compile import aot_function, make_boxed_func

    def compiler(gm, _):
      return make_boxed_func(gm)

    compiled_splash_attention = aot_function(
        splash_attention, fw_compiler=compiler)

    q_sa = self.q_sa.clone().detach().requires_grad_(True)
    k_sa = self.k_sa.clone().detach().requires_grad_(True)
    v_sa = self.v_sa.clone().detach().requires_grad_(True)
    for i in [q_sa, k_sa, v_sa]:
      i.retain_grad()
    segment_ids_sa = self.segment_ids_sa.clone().detach()
    o_sa = compiled_splash_attention(
        q_sa,
        k_sa,
        v_sa,
        self.config.to_json(),
        decoder_segment_ids=segment_ids_sa)
    torch_xla.sync()
    loss_sa = torch.sum(o_sa)
    loss_sa.backward()
    torch_xla.sync()
    q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad

    torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5)
    for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad],
                                 [q_grad_sa, k_grad_sa, v_grad_sa],
                                 strict=False):
      torch.testing.assert_close(
          org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2)

  @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  @with_jax_high_precision  # Removing this decorator causes failures in other tests.
  def test_splash_attention_cache_hit(self):
    xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
    starting_cache_misses = xb._jax_to_xla_computation_cache_num_misses()
    q = self.q_sa.clone().detach().requires_grad_(True)
    k = self.k_sa.clone().detach().requires_grad_(True)
    v = self.v_sa.clone().detach().requires_grad_(True)
    segment_ids = self.segment_ids_sa.clone().detach()
    o = splash_attention(
        q,
        k,
        v,
        self.config.to_json(),
        decoder_segment_ids=segment_ids.to("xla"))
    loss = torch.sum(o)
    loss.backward()
    torch_xla.sync()

    q = self.q_sa.clone().detach().requires_grad_(True)
    k = self.k_sa.clone().detach().requires_grad_(True)
    v = self.v_sa.clone().detach().requires_grad_(True)
    q = q * 2
    segment_ids = self.segment_ids_sa.clone().detach()
    o = splash_attention(
        q,
        k,
        v,
        self.config.to_json(),
        decoder_segment_ids=segment_ids.to("xla"))
    loss = torch.sum(o)
    loss.backward()
    torch_xla.sync()
    ending_cache_misses = xb._jax_to_xla_computation_cache_num_misses()
    # There are 2 misses because we run both the forward (+1 miss) and the
    # backward (+1 miss) pass.
    self.assertEqual(ending_cache_misses - starting_cache_misses, 2)


if __name__ == "__main__":
  logging.getLogger().setLevel(logging.INFO)
  torch.set_default_dtype(torch.float32)
  torch_xla._XLAC._xla_set_mat_mul_precision("highest")
  torch.manual_seed(42)
  xr.use_spmd()
  num_devices = xr.global_runtime_device_count()
  mesh_shape = (num_devices // 2, 2)
  device_ids = np.array(range(num_devices))
  mesh = Mesh(device_ids, mesh_shape, ("data", "fsdp"))
  xs.set_global_mesh(mesh)
  # Run without exiting so the result can be inspected below.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v
 python3 "$TEST_CDIR/test_pallas_spmd.py"
 XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
+python3 "$TEST_CDIR/test_splash_attention.py"
 python3 "$TEST_CDIR/test_profiler_session.py"
 python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py"
 python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py"

torch_xla/core/xla_builder.py

Lines changed: 8 additions & 2 deletions
@@ -969,9 +969,15 @@ def _jax_to_xla_computation_cache_num_misses() -> int:
   return size


+# Be cautious about using this cache. JAX config changes
+# (https://github.com/jax-ml/jax/blob/3864c4f335d1d236d5367264f3885dfce8721d9d/jax/_src/config.py#L254)
+# are not reflected in the call_jax function arguments. However, the config is
+# embedded at the HLO level (e.g., data precision), which can cause
+# computations traced under different JAX configs to reuse the same HLO.
 _JAX_TO_XLA_COMPUTATION_CACHE = WeakKeyDictionary()


+@requires_jax
 def call_jax(jax_func,
              args: tuple[Any, ...],
              kwargs: Optional[dict[str, Any]] = None,
@@ -1016,9 +1022,9 @@ def call_jax(jax_func,
   works. If you get tracing overhead, check if `jax_func` is being redefined all the time.
   A common mistake is defining `jax_func` as a local function, e.g. during a training step.
   """
-
+  import jax
   kwargs = kwargs or {}
-  flattened, _spec = tree_flatten((args, kwargs))
+  flattened, _spec = jax.tree.flatten((args, kwargs))
   xla_computation = jax_func_to_xla_computation(jax_func, args, kwargs, name)
   return xla_computation(flattened)

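The new comment above _JAX_TO_XLA_COMPUTATION_CACHE warns that the cache key does not capture JAX config state, even though that state gets baked into the lowered HLO. Below is a minimal sketch of that caveat, assuming a TPU/XLA environment with JAX available; the matmul function and tensor shapes are illustrative, not part of the commit.

import jax
import torch
import torch_xla
import torch_xla.core.xla_builder as xb

def matmul(a, b):
  # Traced by JAX under whatever jax_default_matmul_precision is active.
  return a @ b

a = torch.randn(128, 128).to("xla")
b = torch.randn(128, 128).to("xla")

jax.config.update("jax_default_matmul_precision", "highest")
out_high = xb.call_jax(matmul, (a, b))  # Cache miss: HLO lowered with "highest".

jax.config.update("jax_default_matmul_precision", "default")
# Same function object and argument shapes, so this call can hit the cache and
# reuse the HLO that was lowered under "highest" precision.
out_default = xb.call_jax(matmul, (a, b))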