
Commit 84a7a34

drisspg authored and pytorchmergebot committed
[FlexFlash] Specify lowering w/ new BACKEND kernel option (pytorch#168017)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

Align w/ naming convention

Pull Request resolved: pytorch#168017
Approved by: https://github.com/Chillee, https://github.com/Skylion007
1 parent c566552 commit 84a7a34
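For reference, a minimal usage sketch of the new BACKEND kernel option, pieced together from the tests added in this commit (the shapes, the CUDA device, and flash-kernel availability on a given GPU are illustrative assumptions, not part of the diff):

# Sketch only: mirrors the kernel_options used by the new tests in this commit.
# Assumes a CUDA device; "FLASH" additionally assumes a GPU supported by the
# flash kernel exercised in test_flex_flash.py.
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 2, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2, 256, 64, device="cuda", dtype=torch.float16)
compiled_fn = torch.compile(flex_attention, fullgraph=True)

# "AUTO" (the default) lets the lowering heuristics choose the kernel.
out_auto = compiled_fn(q, k, v, kernel_options={"BACKEND": "AUTO"})
# "TRITON" pins the Triton flex_attention kernel; "TRITON_DECODE" pins the
# flex-decoding kernel and raises if the shape cannot use it.
out_triton = compiled_fn(q, k, v, kernel_options={"BACKEND": "TRITON"})
# "FLASH" pins the flash kernel, used where these tests previously passed force_flash=True.
out_flash = compiled_fn(q, k, v, kernel_options={"BACKEND": "FLASH"})

Per the new tests, an invalid BACKEND value is rejected by _apply_kernel_options, and combining BACKEND with the legacy FORCE_USE_FLEX_ATTENTION flag raises a RuntimeError.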

File tree: 6 files changed (+281, -38 lines)


test/export/test_export.py

Lines changed: 1 addition & 1 deletion
@@ -968,7 +968,7 @@ def forward(self, x):
 view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None
 sdpa_score0 = self.sdpa_score0
 sdpa_mask0 = self.sdpa_mask0
-flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
+flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
 getitem = flex_attention[0]
 getitem_1 = flex_attention[1]; getitem_1 = None
 getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None

test/inductor/test_flex_attention.py

Lines changed: 184 additions & 5 deletions
@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from itertools import product
 from typing import Optional, TypeVar, Union
-from unittest import expectedFailure, skip, skipUnless
+from unittest import expectedFailure, mock, skip, skipUnless
 from unittest.mock import patch

 import torch
@@ -28,6 +28,7 @@
 from torch.nn.attention import SDPBackend
 from torch.nn.attention.experimental._paged_attention import PagedAttention
 from torch.nn.attention.flex_attention import (
+    _apply_kernel_options,
     _create_empty_block_mask,
     _DEFAULT_SPARSE_BLOCK_SIZE,
     _identity,
@@ -3522,6 +3523,184 @@ def test_kernel_options_argument_is_respected(self, device):
         )
         FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])

+    @supported_platform
+    @skip_on_cpu
+    def test_backend_auto_matches_triton_large(self, device):
+        """BACKEND='AUTO' should follow Triton heuristics on large shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        default_out, default_code = compile_and_run({"BACKEND": "AUTO"})
+        triton_out, triton_code = compile_and_run({"BACKEND": "TRITON"})
+
+        torch.testing.assert_close(default_out, triton_out, atol=0.0, rtol=0.0)
+
+        default_src = "\n".join(default_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(default_src)
+
+        triton_src = "\n".join(triton_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(triton_src)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_matches_auto(self, device):
+        """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 64, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        from torch._inductor.kernel.flex import flex_attention as flex_kernel_mod
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            default_out, _ = compile_and_run({"BACKEND": "AUTO"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected heuristics to dispatch to flex decoding kernel.",
+            )
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            decode_out, _ = compile_and_run({"BACKEND": "TRITON_DECODE"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected explicit BACKEND='TRITON_DECODE' to use flex decoding kernel.",
+            )
+
+        self.assertEqual(decode_out.shape, (1, 2, 64, 64))
+        torch.testing.assert_close(default_out, decode_out, atol=3e-3, rtol=3e-3)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_errors_when_not_supported(self, device):
+        """Requesting decode on unsupported shapes should raise a helpful error."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used",
+        ):
+            flex_compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"})
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device):
+        """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two."""
+        q = torch.randn(
+            1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+        k = torch.randn(
+            1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+        v = torch.randn(
+            1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False
+        )
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used",
+        ):
+            flex_compiled(
+                q,
+                k,
+                v,
+                enable_gqa=True,
+                kernel_options={"BACKEND": "TRITON_DECODE"},
+            )
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_rejects_legacy_force_use_flag(self, device):
+        """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 128, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION",
+        ):
+            flex_compiled(
+                q,
+                k,
+                v,
+                kernel_options={
+                    "BACKEND": "TRITON",
+                    "FORCE_USE_FLEX_ATTENTION": True,
+                },
+            )
+
+    @supported_platform
+    def test_backend_defaults_and_rejects_invalid(self, device):
+        device = torch.device(device)
+        query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+        key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+        value = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
+
+        kernel_options = _apply_kernel_options(
+            query, key, value, return_lse=True, kernel_options={}
+        )
+        self.assertEqual(kernel_options["BACKEND"], "AUTO")
+
+        with self.assertRaisesRegex(ValueError, r"Invalid BACKEND value 'INVALID'"):
+            _apply_kernel_options(
+                query,
+                key,
+                value,
+                return_lse=True,
+                kernel_options={"BACKEND": "INVALID"},
+            )
+
     @supported_platform
     def test_block_mask_non_divisible(self, device):
         seq = torch.arange(1023, device=device) // 128
@@ -4154,7 +4333,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_

 score_mod_0 = self.score_mod_0
 mask_fn_0 = self.mask_fn_0
-flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
 out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
 return (out,)
@@ -4190,11 +4369,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
 """\
 class GraphModule(torch.nn.Module):
 def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
-full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
 fw_graph0 = self.fw_graph0
 joint_graph0 = self.joint_graph0
 mask_graph0 = self.mask_graph0
-flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
+flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
 getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
 getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
 getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@@ -4214,7 +4393,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3

 class mask_graph0(torch.nn.Module):
 def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
 return full_default
 """.replace( # noqa: B950
 "GPU_TYPE", torch.device(device).type

test/inductor/test_flex_flash.py

Lines changed: 12 additions & 14 deletions
@@ -139,15 +139,15 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2):
         v,
         score_mod=score_mod,
         block_mask=block_mask,
-        kernel_options={"force_flash": True},
+        kernel_options={"BACKEND": "FLASH"},
     )
     out_triton = compiled_fn(
         q,
         k,
         v,
         score_mod=score_mod,
         block_mask=block_mask,
-        kernel_options={"force_flash": False},
+        kernel_options={"BACKEND": "TRITON"},
     )

     assert out_flash.shape == out_ref_fp32.shape == out_triton.shape
@@ -200,30 +200,28 @@ def test_flash_attention_unfriendly_seqlen_with_causal(

     @dtypes(torch.float16, torch.bfloat16)
     def test_flash_attention_kernel_called(self, device, dtype):
-        """Test that flash attention kernel is actually called when force_flash=True."""
+        """Test that flash attention kernel is actually called when BACKEND='FLASH'."""
         q, k, v = create_test_tensors(dtype=dtype, device=device)
         compiled_fn = torch.compile(flex_attention)

-        # Test that flash kernel is called with force_flash=True
+        # Test that flash kernel is called with BACKEND='FLASH'
         with cuda_kernel_profiler("flash_attncute") as prof_result:
-            compiled_fn(
-                q, k, v, score_mod=_causal, kernel_options={"force_flash": True}
-            )
+            compiled_fn(q, k, v, score_mod=_causal, kernel_options={"BACKEND": "FLASH"})

         self.assertTrue(
             prof_result["found"],
             f"Flash attention kernel not found. Available kernels: {prof_result['kernel_names']}",
         )

-        # Test that flash kernel is NOT called with force_flash=False
+        # Test that flash kernel is NOT called with BACKEND='TRITON'
         with cuda_kernel_profiler("flash_attncute") as prof_result:
             compiled_fn(
-                q, k, v, score_mod=_causal, kernel_options={"force_flash": False}
+                q, k, v, score_mod=_causal, kernel_options={"BACKEND": "TRITON"}
             )

         self.assertFalse(
             prof_result["found"],
-            f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}",
+            f"Flash attention kernel unexpectedly found when BACKEND='TRITON'. Kernels: {prof_result['kernel_names']}",
         )

     @dtypes(torch.float16, torch.bfloat16)
@@ -284,8 +282,8 @@ def score_view_mod(score, b, h, q_idx, kv_idx):
         flash_vs_triton(q, k, v, score_mod=score_view_mod)

     @dtypes(torch.float16, torch.bfloat16)
-    def test_force_flash_error_with_requires_grad(self, device, dtype):
-        """Test that force_flash=True raises error when tensor requires gradients."""
+    def test_flash_impl_error_with_requires_grad(self, device, dtype):
+        """Test that BACKEND='FLASH' raises error when tensor requires gradients."""
         q, k, v = create_test_tensors(dtype=dtype, device=device)

         bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True)
@@ -296,14 +294,14 @@ def score_mod_with_grad(score, b, h, q_idx, kv_idx):
         compiled_fn = torch.compile(flex_attention)
         with self.assertRaisesRegex(
             RuntimeError,
-            r"force_flash=True but flash attention cannot be used.*require gradients",
+            r"BACKEND='FLASH' but flash attention cannot be used.*require gradients",
         ):
             compiled_fn(
                 q,
                 k,
                 v,
                 score_mod=score_mod_with_grad,
-                kernel_options={"force_flash": True},
+                kernel_options={"BACKEND": "FLASH"},
             )

     @dtypes(torch.float16, torch.bfloat16)

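Taken together, the test_flex_flash.py changes retire the boolean force_flash kernel option in favor of the string-valued BACKEND option. A rough migration sketch, mirroring the substitutions above (the tensor shapes and the causal score_mod here are illustrative assumptions, not taken from the diff):

# Migration sketch: old force_flash flag -> new BACKEND option. Assumes a CUDA device.
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # Illustrative score_mod, standing in for the _causal helper used in the tests.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
compiled_fn = torch.compile(flex_attention)

# Before: kernel_options={"force_flash": True}  /  {"force_flash": False}
out_flash = compiled_fn(q, k, v, score_mod=causal, kernel_options={"BACKEND": "FLASH"})
out_triton = compiled_fn(q, k, v, score_mod=causal, kernel_options={"BACKEND": "TRITON"})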
0 commit comments