
Commit 97ca0b4

remi-or, MekkCyber, and github-actions[bot] authored
Fix flash-attn for paged_attention when no kernels (#41078)
* Fix non-kernels flash attention paged implementation
* Cover all cases
* Style
* Update src/transformers/integrations/flash_paged.py
  Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
* Apply style fixes

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 53838ed commit 97ca0b4

File tree

2 files changed (+16, -3 lines)


examples/pytorch/continuous_batching.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def generate_simple(
     attn_impl = {
         "sdpa_paged": "sdpa",
         "eager_paged": "eager",
-        "flash_paged": "flash_attention_2",
+        "paged_attention": "eager",  # TODO: this does not work on AMD docker
+        "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
     }[attn_impl]
 
     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
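For context, here is a minimal standalone sketch of how this remapping is used by the example: the paged attention name chosen for continuous batching is translated to a non-paged implementation for the reference model. The helper name reference_attn_impl is hypothetical, introduced only for illustration.

# Hypothetical helper mirroring the dict above (not part of the example script):
# map a paged attention implementation name to the non-paged implementation
# used when loading the reference model for comparison.
def reference_attn_impl(paged_impl: str) -> str:
    return {
        "sdpa_paged": "sdpa",
        "eager_paged": "eager",
        "paged_attention": "eager",
        "flash_paged": "flash_attention_2",
    }[paged_impl]


assert reference_attn_impl("paged_attention") == "eager"
assert reference_attn_impl("flash_paged") == "flash_attention_2"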

src/transformers/integrations/flash_paged.py

Lines changed: 14 additions & 2 deletions
@@ -6,11 +6,21 @@
 from ..utils import is_flash_attn_2_available
 
 
+# For some reason, if we dont assign the function to a variable here, it will be garbage collected
 try:
     if is_flash_attn_2_available():
         from flash_attn import flash_attn_varlen_func  # noqa: F401
-except Exception:
-    pass
+
+        FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func
+    else:
+        raise RuntimeError(
+            "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it"
+        )
+except Exception as e:
+    msg = repr(e)
+
+    def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs):
+        raise Exception(f"flash_attn_varlen_func is not available: {msg}")
 
 
 def paged_attention_forward(

@@ -63,6 +73,8 @@ def paged_attention_forward(
 
     if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
+    else:
+        flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC
 
     custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
 
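Taken together, the change means a missing or broken flash-attn install no longer fails silently and then surfaces as an unbound flash_attn_varlen_func inside paged_attention_forward. Below is a minimal, self-contained sketch of the pattern, assuming flash-attn may be absent; demo_dispatch and its argument are hypothetical stand-ins for the dispatch inside paged_attention_forward, not part of the module.

# Sketch of the guarded import plus call-time fallback (hypothetical standalone
# reproduction, not the transformers module itself).
try:
    from flash_attn import flash_attn_varlen_func  # import may fail for many reasons

    FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func
except Exception as e:
    msg = repr(e)

    def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs):
        # Surface the original import failure at call time instead of a NameError.
        raise RuntimeError(f"flash_attn_varlen_func is not available: {msg}")


def demo_dispatch(implementation=None):
    # Mirrors the dispatch added to paged_attention_forward: prefer a
    # kernels-provided implementation, otherwise fall back to the module-level guard.
    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
        return implementation.flash_attn_varlen_func
    return FLASH_ATTN_VARLEN_FUNC

Without flash-attn installed, demo_dispatch() returns the stub, so the first attention call raises a clear error carrying the original import failure rather than an undefined-name error.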