@@ -8,10 +8,8 @@
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
+import torch.nn.functional as F
 
-from tests.kernels.utils import make_alibi_bias
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
@@ -28,6 +26,74 @@
 OPS = [chunked_prefill_paged_decode, context_attention_fwd]
 
 
+def create_causal_attention_mask_for_sdpa(
+    query_lens: list[int],
+    seq_lens: list[int],
+    sliding_window: int = 0,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    total_queries = sum(query_lens)
+    total_keys = sum(seq_lens)
+
+    # Create a mask filled with -inf
+    mask = torch.full(
+        (total_queries, total_keys), float("-inf"), device=device, dtype=dtype
+    )
+
+    query_start = 0
+    key_start = 0
+
+    for query_len, seq_len in zip(query_lens, seq_lens):
+        query_end = query_start + query_len
+        key_end = key_start + seq_len
+        q_indices = torch.arange(query_len, device=device)
+        k_indices = torch.arange(seq_len, device=device)
+        q_pos_in_seq = seq_len - query_len + q_indices
+
+        valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
+
+        if sliding_window > 0:
+            valid_mask &= k_indices[None, :] >= (
+                q_pos_in_seq[:, None] - sliding_window + 1
+            )
+
+        mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
+
+        query_start = query_end
+        key_start = key_end
+
+    return mask
+
+
+def create_alibi_causal_mask(
+    query_len: int,
+    seq_len: int,
+    alibi_slopes: torch.Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    query_pos = torch.arange(
+        seq_len - query_len, seq_len, device=device, dtype=torch.float32
+    )
+    key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+    rel_pos = key_pos[None, :] - query_pos[:, None]
+
+    # Apply ALiBi slopes: [num_heads, query_len, seq_len]
+    alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
+    alibi_bias = alibi_bias.to(dtype)
+
+    # Apply causal mask: prevent attending to future positions
+    # causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
+    causal_mask = key_pos[None, :] <= query_pos[:, None]
+    alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
+
+    # Add batch dimension: [1, num_heads, query_len, seq_len]
+    # SDPA expects batch dimension even for single sequences
+    return alibi_bias.unsqueeze(0)
+
+
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -52,6 +118,13 @@ def test_contexted_kv_attention(
52118 "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
53119 )
54120
121+ if (
122+ current_platform .is_rocm ()
123+ and op is chunked_prefill_paged_decode
124+ and kv_cache_dtype == "fp8_e5m2"
125+ ):
126+ pytest .skip ("ROCm custom paged attention does not support fp8_e5m2 KV cache" )
127+
55128 current_platform .seed_everything (0 )
56129 torch .set_default_device (device )
57130
@@ -96,16 +169,16 @@ def test_contexted_kv_attention(
     )
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
-    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = torch.arange(0, cache_size, dtype=torch.int32)
     values = values[torch.randperm(cache_size)]
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
-    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
-    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
+        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -189,56 +262,57 @@ def test_contexted_kv_attention(
 
     scale = float(1.0 / (head_size**0.5))
 
-    attn_op = xops.fmha.cutlass.FwOp()
+    # Reshape for SDPA: (seq_len, num_heads, head_size) ->
+    # (1, num_heads, seq_len, head_size)
+    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
+    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, num_tokens, head_size
+    )
 
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        query = query.view(
-            query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
-        )
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
+    # Expand key and value for GQA/MQA to match query heads
+    key_sdpa = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
 
-    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
-        query_lens, seq_lens
+    value_sdpa = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
     )
-    if sliding_window > 0:
-        attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
+
+    attn_mask = create_causal_attention_mask_for_sdpa(
+        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
+    )
+
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     start_time = time.time()
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
-    output_ref = output_ref.reshape(output.shape)
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
+
+    # Reshape output back to (num_tokens, num_heads, head_size)
+    output_ref = output_ref.view(num_heads, num_tokens, head_size)
+    output_ref = output_ref.permute(1, 0, 2).contiguous()
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
@@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
265339 "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
266340 )
267341
342+ if (
343+ current_platform .is_rocm ()
344+ and op is chunked_prefill_paged_decode
345+ and kv_cache_dtype == "fp8_e5m2"
346+ ):
347+ pytest .skip ("ROCm custom paged attention does not support fp8_e5m2 KV cache" )
348+
268349 current_platform .seed_everything (0 )
269350 torch .set_default_device (device )
270351
@@ -331,16 +412,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     )
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
-    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = torch.arange(0, cache_size, dtype=torch.int32)
     values = values[torch.randperm(cache_size)]
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
-    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
-    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
+        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
     )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -423,78 +504,75 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
     scale = float(1.0 / (head_size**0.5))
 
-    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
-    # we have to pad query tensor before MQA/GQA expanding.
-    if query.shape[0] != key.shape[0]:
-        query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
-        query_pad.uniform_(-1e-3, 1e-3)
-        seq_start = 0
-        query_start = 0
-        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-            seq_end = seq_start + seq_len
-            query_end = query_start + query_len
-            query_pad[seq_start:seq_end, ...] = torch.cat(
-                [
-                    torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
-                    query[query_start:query_end, ...],
-                ],
-                dim=0,
-            )
-            seq_start += seq_len
-            query_start += query_len
-        query = query_pad
-
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
-        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
-        # codebase. We save some time reshaping alibi matrix at runtime.
-        key = key.reshape(key.shape[0], -1, key.shape[-1])
-        value = value.reshape(value.shape[0], -1, value.shape[-1])
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-
-    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    # Prepare query, key, value for SDPA
+    # Expand key and value for GQA/MQA to match query heads
+    key_expanded = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    value_expanded = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
+    )
+
     output_ref = torch.empty_like(output)
-    seq_start = 0
-    query_start = 0
+
+    torch.cuda.synchronize()
     start_time = time.time()
-    # Attention with alibi slopes.
-    # FIXME(DefTruth): Because xformers does not support dynamic sequence
-    # lengths with custom attention bias, we process each prompt one by
-    # one. This is inefficient, especially when we have many short prompts.
-    # modified from: vllm/v1/attention/backends/xformers.py#L343
+
+    query_start = 0
+    key_start = 0
     for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-        seq_end = seq_start + seq_len
         query_end = query_start + query_len
-        out = xops.memory_efficient_attention_forward(
-            query[:, seq_start:seq_end],
-            key[:, seq_start:seq_end],
-            value[:, seq_start:seq_end],
-            attn_bias=attn_bias[i],
-            p=0.0,
-            scale=scale,
+        key_end = key_start + seq_len
+
+        # Get query, key, value for this sequence
+        q = query[query_start:query_end]  # [query_len, num_heads, head_size]
+        k = key_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+        v = value_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+
+        # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
+        q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
+        q_sdpa = (
+            q_sdpa.permute(1, 2, 0, 3)
+            .reshape(1, num_heads, query_len, head_size)
+            .contiguous()
+        )
+
+        k_sdpa = (
+            k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
+        )
+        v_sdpa = (
+            v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
         )
-        out = out.view_as(query[:, seq_start:seq_end]).view(
-            seq_len, num_heads, head_size
+
+        # Create ALiBi causal mask for this sequence using utility function
+        alibi_mask = create_alibi_causal_mask(
+            query_len, seq_len, alibi_slopes, device, dtype
+        )
+
+        # Compute attention
+        out = F.scaled_dot_product_attention(
+            q_sdpa,
+            k_sdpa,
+            v_sdpa,
+            attn_mask=alibi_mask,
+            dropout_p=0.0,
+            scale=scale,
         )
-        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
-        seq_start += seq_len
-        query_start += query_len
+
+        # Reshape output back to [query_len, num_heads, head_size]
+        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
+        output_ref[query_start:query_end].copy_(out)
+
+        query_start = query_end
+        key_start = key_end
+
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
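For orientation, here is a small standalone sketch (illustrative only, not part of the patch) of the masking idea the new create_causal_attention_mask_for_sdpa helper applies: a block-diagonal, bottom-right-aligned causal mask passed to torch.nn.functional.scaled_dot_product_attention. The toy sizes (query_lens=[1, 2], seq_lens=[3, 4], 2 heads) are arbitrary assumptions.

import torch
import torch.nn.functional as F

# Toy sizes, chosen only for illustration.
query_lens, seq_lens = [1, 2], [3, 4]
num_heads, head_size = 2, 8
total_q, total_k = sum(query_lens), sum(seq_lens)

# Same rule as the helper in the diff: query i of a sequence sits at
# position seq_len - query_len + i and may attend to keys at positions
# <= its own; everything outside its own block stays at -inf.
mask = torch.full((total_q, total_k), float("-inf"))
q_off, k_off = 0, 0
for q_len, s_len in zip(query_lens, seq_lens):
    q_pos = torch.arange(s_len - q_len, s_len)
    k_pos = torch.arange(s_len)
    allowed = k_pos[None, :] <= q_pos[:, None]
    mask[q_off : q_off + q_len, k_off : k_off + s_len][allowed] = 0.0
    q_off += q_len
    k_off += s_len

q = torch.randn(1, num_heads, total_q, head_size)
k = torch.randn(1, num_heads, total_k, head_size)
v = torch.randn(1, num_heads, total_k, head_size)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 3, 8])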