
Commit e553424

zhewenli and ywang96 authored
[CI/Build] Refactor Attention backend for test_prefix_prefill from xformers to SDPA (vllm-project#28424)
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
1 parent 5a1271d commit e553424
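
Purely as a reading aid (not part of the commit): the diff below replaces xformers' bottom-right block-diagonal causal bias with an explicit additive mask passed to torch.nn.functional.scaled_dot_product_attention. A minimal, self-contained sketch of that substitution follows; the shapes and variable names are illustrative only and do not come from the test itself.

# Hypothetical sketch of the SDPA reference path this commit introduces.
import torch
import torch.nn.functional as F

num_heads, head_size = 4, 64
query_len, seq_len = 3, 8  # 5 cached context tokens + 3 new query tokens

q = torch.randn(1, num_heads, query_len, head_size)
k = torch.randn(1, num_heads, seq_len, head_size)
v = torch.randn(1, num_heads, seq_len, head_size)

# Bottom-right-aligned causal mask: query i sits at absolute position
# seq_len - query_len + i and may attend to every key at or before it.
q_pos = torch.arange(seq_len - query_len, seq_len)
k_pos = torch.arange(seq_len)
mask = torch.where(
    k_pos[None, :] <= q_pos[:, None], 0.0, float("-inf")
)  # additive mask of shape [query_len, seq_len], broadcast by SDPA

out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, scale=head_size**-0.5
)
print(out.shape)  # torch.Size([1, 4, 3, 64])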

File tree

1 file changed: +194, -116 lines changed


tests/kernels/attention/test_prefix_prefill.py

Lines changed: 194 additions & 116 deletions
@@ -8,10 +8,8 @@
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
+import torch.nn.functional as F
 
-from tests.kernels.utils import make_alibi_bias
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
@@ -28,6 +26,74 @@
 OPS = [chunked_prefill_paged_decode, context_attention_fwd]
 
 
+def create_causal_attention_mask_for_sdpa(
+    query_lens: list[int],
+    seq_lens: list[int],
+    sliding_window: int = 0,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    total_queries = sum(query_lens)
+    total_keys = sum(seq_lens)
+
+    # Create a mask filled with -inf
+    mask = torch.full(
+        (total_queries, total_keys), float("-inf"), device=device, dtype=dtype
+    )
+
+    query_start = 0
+    key_start = 0
+
+    for query_len, seq_len in zip(query_lens, seq_lens):
+        query_end = query_start + query_len
+        key_end = key_start + seq_len
+        q_indices = torch.arange(query_len, device=device)
+        k_indices = torch.arange(seq_len, device=device)
+        q_pos_in_seq = seq_len - query_len + q_indices
+
+        valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
+
+        if sliding_window > 0:
+            valid_mask &= k_indices[None, :] >= (
+                q_pos_in_seq[:, None] - sliding_window + 1
+            )
+
+        mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
+
+        query_start = query_end
+        key_start = key_end
+
+    return mask
+
+
+def create_alibi_causal_mask(
+    query_len: int,
+    seq_len: int,
+    alibi_slopes: torch.Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    query_pos = torch.arange(
+        seq_len - query_len, seq_len, device=device, dtype=torch.float32
+    )
+    key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+    rel_pos = key_pos[None, :] - query_pos[:, None]
+
+    # Apply ALiBi slopes: [num_heads, query_len, seq_len]
+    alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
+    alibi_bias = alibi_bias.to(dtype)
+
+    # Apply causal mask: prevent attending to future positions
+    # causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
+    causal_mask = key_pos[None, :] <= query_pos[:, None]
+    alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
+
+    # Add batch dimension: [1, num_heads, query_len, seq_len]
+    # SDPA expects batch dimension even for single sequences
+    return alibi_bias.unsqueeze(0)
+
+
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -52,6 +118,13 @@ def test_contexted_kv_attention(
             "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
         )
 
+    if (
+        current_platform.is_rocm()
+        and op is chunked_prefill_paged_decode
+        and kv_cache_dtype == "fp8_e5m2"
+    ):
+        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 
@@ -96,16 +169,16 @@ def test_contexted_kv_attention(
     )
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
-    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = torch.arange(0, cache_size, dtype=torch.int32)
     values = values[torch.randperm(cache_size)]
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
-    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
-    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
+        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -189,56 +262,57 @@ def test_contexted_kv_attention(
 
     scale = float(1.0 / (head_size**0.5))
 
-    attn_op = xops.fmha.cutlass.FwOp()
+    # Reshape for SDPA: (seq_len, num_heads, head_size) ->
+    # (1, num_heads, seq_len, head_size)
+    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
+    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, num_tokens, head_size
+    )
 
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        query = query.view(
-            query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
-        )
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-    query = query.unsqueeze(0)
-    key = key.unsqueeze(0)
-    value = value.unsqueeze(0)
+    # Expand key and value for GQA/MQA to match query heads
+    key_sdpa = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
 
-    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
-        query_lens, seq_lens
+    value_sdpa = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
     )
-    if sliding_window > 0:
-        attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
+
+    attn_mask = create_causal_attention_mask_for_sdpa(
+        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
+    )
+
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     start_time = time.time()
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
-    output_ref = output_ref.reshape(output.shape)
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
+
+    # Reshape output back to (num_tokens, num_heads, head_size)
+    output_ref = output_ref.view(num_heads, num_tokens, head_size)
+    output_ref = output_ref.permute(1, 0, 2).contiguous()
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
@@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
             "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
         )
 
+    if (
+        current_platform.is_rocm()
+        and op is chunked_prefill_paged_decode
+        and kv_cache_dtype == "fp8_e5m2"
+    ):
+        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 

@@ -331,16 +412,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
331412
)
332413
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
333414
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
334-
values = torch.arange(0, cache_size, dtype=torch.long)
415+
values = torch.arange(0, cache_size, dtype=torch.int32)
335416
values = values[torch.randperm(cache_size)]
336417
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
337-
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
338-
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
339-
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
418+
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
419+
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
420+
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
340421
max_input_len = MAX_SEQ_LEN
341422
# copy kv to cache
342423
b_seq_start_loc = torch.cumsum(
343-
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
424+
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
344425
)
345426
for i in range(BS):
346427
for j in range(query_lens[i]):
@@ -423,78 +504,75 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
     scale = float(1.0 / (head_size**0.5))
 
-    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
-    # we have to pad query tensor before MQA/GQA expanding.
-    if query.shape[0] != key.shape[0]:
-        query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
-        query_pad.uniform_(-1e-3, 1e-3)
-        seq_start = 0
-        query_start = 0
-        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-            seq_end = seq_start + seq_len
-            query_end = query_start + query_len
-            query_pad[seq_start:seq_end, ...] = torch.cat(
-                [
-                    torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
-                    query[query_start:query_end, ...],
-                ],
-                dim=0,
-            )
-            seq_start += seq_len
-            query_start += query_len
-        query = query_pad
-
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
-        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
-        # codebase. We save some time reshaping alibi matrix at runtime.
-        key = key.reshape(key.shape[0], -1, key.shape[-1])
-        value = value.reshape(value.shape[0], -1, value.shape[-1])
-    query = query.unsqueeze(0)
-    key = key.unsqueeze(0)
-    value = value.unsqueeze(0)
-
-    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    # Prepare query, key, value for SDPA
+    # Expand key and value for GQA/MQA to match query heads
+    key_expanded = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    value_expanded = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
+    )
+
     output_ref = torch.empty_like(output)
-    seq_start = 0
-    query_start = 0
+
+    torch.cuda.synchronize()
     start_time = time.time()
-    # Attention with alibi slopes.
-    # FIXME(DefTruth): Because xformers does not support dynamic sequence
-    # lengths with custom attention bias, we process each prompt one by
-    # one. This is inefficient, especially when we have many short prompts.
-    # modified from: vllm/v1/attention/backends/xformers.py#L343
+
+    query_start = 0
+    key_start = 0
     for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-        seq_end = seq_start + seq_len
         query_end = query_start + query_len
-        out = xops.memory_efficient_attention_forward(
-            query[:, seq_start:seq_end],
-            key[:, seq_start:seq_end],
-            value[:, seq_start:seq_end],
-            attn_bias=attn_bias[i],
-            p=0.0,
-            scale=scale,
+        key_end = key_start + seq_len
+
+        # Get query, key, value for this sequence
+        q = query[query_start:query_end]  # [query_len, num_heads, head_size]
+        k = key_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+        v = value_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+
+        # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
+        q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
+        q_sdpa = (
+            q_sdpa.permute(1, 2, 0, 3)
+            .reshape(1, num_heads, query_len, head_size)
+            .contiguous()
+        )
+
+        k_sdpa = (
+            k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
+        )
+        v_sdpa = (
+            v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
         )
-        out = out.view_as(query[:, seq_start:seq_end]).view(
-            seq_len, num_heads, head_size
+
+        # Create ALiBi causal mask for this sequence using utility function
+        alibi_mask = create_alibi_causal_mask(
+            query_len, seq_len, alibi_slopes, device, dtype
+        )
+
+        # Compute attention
+        out = F.scaled_dot_product_attention(
+            q_sdpa,
+            k_sdpa,
+            v_sdpa,
+            attn_mask=alibi_mask,
+            dropout_p=0.0,
+            scale=scale,
         )
-        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
-        seq_start += seq_len
-        query_start += query_len
+
+        # Reshape output back to [query_len, num_heads, head_size]
+        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
+        output_ref[query_start:query_end].copy_(out)
+
+        query_start = query_end
+        key_start = key_end
+
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
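
For intuition about the block-diagonal causal layout the new helper builds, here is a small standalone sketch (not part of the commit). It assumes the create_causal_attention_mask_for_sdpa definition from the hunk above has been copied into scope, and the sequence lengths are made up for illustration.

# Assumes create_causal_attention_mask_for_sdpa (defined in the diff above) is in scope.
import torch

mask = create_causal_attention_mask_for_sdpa(
    query_lens=[2, 1], seq_lens=[4, 3], device=torch.device("cpu"), dtype=torch.float32
)
print(mask.shape)  # torch.Size([3, 7]): 3 query rows over 4 + 3 keys

# Row 0 is the second-to-last token of sequence 0: it may attend to keys 0..2
# of that sequence and to nothing from sequence 1 (block-diagonal causal layout).
print(mask)
# tensor([[0.,   0.,   0.,   -inf, -inf, -inf, -inf],
#         [0.,   0.,   0.,   0.,   -inf, -inf, -inf],
#         [-inf, -inf, -inf, -inf, 0.,   0.,   0.  ]])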