 from dataclasses import dataclass
 from itertools import product
 from typing import Optional, TypeVar, Union
-from unittest import expectedFailure, skip, skipUnless
+from unittest import expectedFailure, mock, skip, skipUnless
 from unittest.mock import patch

 import torch
 from torch.nn.attention import SDPBackend
 from torch.nn.attention.experimental._paged_attention import PagedAttention
 from torch.nn.attention.flex_attention import (
+    _apply_kernel_options,
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
@@ -3522,6 +3523,184 @@ def test_kernel_options_argument_is_respected(self, device):
         )
         FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])

+    @supported_platform
+    @skip_on_cpu
+    def test_backend_auto_matches_triton_large(self, device):
+        """BACKEND='AUTO' should follow Triton heuristics on large shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        default_out, default_code = compile_and_run({"BACKEND": "AUTO"})
+        triton_out, triton_code = compile_and_run({"BACKEND": "TRITON"})
+
+        torch.testing.assert_close(default_out, triton_out, atol=0.0, rtol=0.0)
+
+        default_src = "\n".join(default_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(default_src)
+
+        triton_src = "\n".join(triton_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(triton_src)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_matches_auto(self, device):
+        """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 64, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        from torch._inductor.kernel.flex import flex_attention as flex_kernel_mod
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            default_out, _ = compile_and_run({"BACKEND": "AUTO"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected heuristics to dispatch to flex decoding kernel.",
+            )
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            decode_out, _ = compile_and_run({"BACKEND": "TRITON_DECODE"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected explicit BACKEND='TRITON_DECODE' to use flex decoding kernel.",
+            )
+
+        self.assertEqual(decode_out.shape, (1, 2, 64, 64))
+        torch.testing.assert_close(default_out, decode_out, atol=3e-3, rtol=3e-3)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_errors_when_not_supported(self, device):
+        """Requesting decode on unsupported shapes should raise a helpful error."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used",
+        ):
+            flex_compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"})
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device):
+        """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two."""
+        q = torch.randn(
+            1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+        k = torch.randn(
+            1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+        v = torch.randn(
+            1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used",
+        ):
+            flex_compiled(
+                q,
+                k,
+                v,
+                enable_gqa=True,
+                kernel_options={"BACKEND": "TRITON_DECODE"},
+            )
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_rejects_legacy_force_use_flag(self, device):
+        """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 128, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION",
+        ):
+            flex_compiled(
+                q,
+                k,
+                v,
+                kernel_options={
+                    "BACKEND": "TRITON",
+                    "FORCE_USE_FLEX_ATTENTION": True,
+                },
+            )
+
+    @supported_platform
+    def test_backend_defaults_and_rejects_invalid(self, device):
+        device = torch.device(device)
+        query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+        key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+        value = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+
+        kernel_options = _apply_kernel_options(
+            query, key, value, return_lse=True, kernel_options={}
+        )
+        self.assertEqual(kernel_options["BACKEND"], "AUTO")
+
+        with self.assertRaisesRegex(ValueError, r"Invalid BACKEND value 'INVALID'"):
+            _apply_kernel_options(
+                query,
+                key,
+                value,
+                return_lse=True,
+                kernel_options={"BACKEND": "INVALID"},
+            )
+
     @supported_platform
     def test_block_mask_non_divisible(self, device):
         seq = torch.arange(1023, device=device) // 128
@@ -4154,7 +4333,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_

         score_mod_0 = self.score_mod_0
         mask_fn_0 = self.mask_fn_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
         out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
         return (out,)

@@ -4190,11 +4369,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
41904369 """\
41914370 class GraphModule(torch.nn.Module):
41924371 def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
4193- full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE ', index=0), pin_memory = False)
4372+ full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda ', index=0), pin_memory = False)
41944373 fw_graph0 = self.fw_graph0
41954374 joint_graph0 = self.joint_graph0
41964375 mask_graph0 = self.mask_graph0
4197- flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
4376+ flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', ' PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
41984377 getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
41994378 getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
42004379 getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@@ -4214,7 +4393,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3

     class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-            full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+            full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full_default
 """.replace(  # noqa: B950
                 "GPU_TYPE", torch.device(device).type
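For context, here is a minimal usage sketch of the `BACKEND` kernel option that the new tests above exercise. The option names (`"AUTO"`, `"TRITON"`, `"TRITON_DECODE"`), the `kernel_options` plumbing, and the error behavior are taken from this diff; the shapes, dtype, and CUDA device below are illustrative assumptions, not part of the patch.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Decode-friendly shapes (short query length), mirroring the tests above.
# Illustrative values only.
q = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)

compiled = torch.compile(flex_attention, fullgraph=True)

# "AUTO" (the default) keeps the existing heuristic choice between the main
# Triton kernel and the flex-decoding kernel; "TRITON" forces the main kernel.
out_auto = compiled(q, k, v, kernel_options={"BACKEND": "AUTO"})
out_triton = compiled(q, k, v, kernel_options={"BACKEND": "TRITON"})

# "TRITON_DECODE" requests the flex-decoding kernel; per the tests above it
# raises a RuntimeError when the shape cannot use it (e.g. long query lengths
# or a non-power-of-two GQA ratio), and it cannot be combined with the legacy
# FORCE_USE_FLEX_ATTENTION flag.
out_decode = compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"})
```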