
Commit 5bdb704

Fix sliding window attn mask (#41228)
* Fix sliding window attn mask
* Clearer test
* Apply style fixes
* If Picasso made ascii drawings he would have made this

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent a61fc6a commit 5bdb704

File tree: 2 files changed, +95 −3 lines changed


src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 52 additions & 3 deletions
@@ -42,7 +42,56 @@ def build_attention_mask(
 ) -> None:
     """Builds an attention mask inplace using the cumulative seqlens of the query and key. If given a sliding window, it
     will also apply a sliding window mask on top. The attention mask is not boolean, it uses zeroes and -inf (or its
-    equivalent) so it's more of an attention score bias tensor."""
+    equivalent) so it's more of an attention score bias tensor.
+    The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair.
+    Each of those blocks is built from a causal mask and, if there is a sliding window, a sliding window mask.
+
+    An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6:
+
+    CAUSAL MASK:
+
+    █ █ █ █ █ ░ ░ ░
+    █ █ █ █ █ █ ░ ░
+    █ █ █ █ █ █ █ ░
+    █ █ █ █ █ █ █ █
+
+    SLIDING WINDOW MASK:
+    ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the right
+    <─┴─>
+    ░ █ | █ █ █ █ █ █ █ █
+    ░ ░ | █ █ █ █ █ █ █ █
+    ░ ░ | ░ █ █ █ █ █ █ █
+    ░ ░ | ░ ░ █ █ █ █ █ █
+
+    ATTENTION MASK (sum of causal and sliding window masks):
+
+    █ █ █ █ █ ░ ░ ░
+    █ █ █ █ █ █ ░ ░
+    ░ █ █ █ █ █ █ ░
+    ░ ░ █ █ █ █ █ █
+
+    Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2:
+
+    CAUSAL MASK:
+
+    █ █ █ ░ ░
+    █ █ █ █ ░
+    █ █ █ █ █
+
+    SLIDING WINDOW MASK:
+    ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the right
+    <┴>
+    | ░ █ █ █ █
+    | ░ ░ █ █ █
+    | ░ ░ ░ █ █
+
+    ATTENTION MASK (sum of causal and sliding window masks):
+
+    ░ █ █ ░ ░
+    ░ ░ █ █ ░
+    ░ ░ ░ █ █
+
+    """
     min_value = torch.finfo(attention_mask.dtype).min
     for i in range(len(cumulative_seqlens_q) - 1):
         seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
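
To make the diagrams above concrete, here is a minimal standalone sketch (not the library function) that reproduces the first docstring example with plain torch calls. The causal offset `seqlen_k - seqlen_q + 1` is an assumption inferred from the diagrams; its actual definition lives outside the hunks shown here.

import torch

# Reproduce the first docstring example: seqlen_k = 8, seqlen_q = 4, sliding_window = 6.
seqlen_q, seqlen_k, sliding_window = 4, 8, 6
minus_inf = torch.full((seqlen_q, seqlen_k), torch.finfo(torch.float32).min)

# Assumed causal offset (inferred from the diagrams, not shown in this diff).
causal = torch.triu(minus_inf, diagonal=seqlen_k - seqlen_q + 1)                  # blocks future keys
sliding = torch.tril(minus_inf, diagonal=seqlen_k - seqlen_q - sliding_window)    # blocks keys that left the window
bias = causal + sliding

# Render like the docstring: '█' where the query may attend (bias == 0), '░' otherwise.
print("\n".join(" ".join("█" if x == 0 else "░" for x in row) for row in bias))
# █ █ █ █ █ ░ ░ ░
# █ █ █ █ █ █ ░ ░
# ░ █ █ █ █ █ █ ░
# ░ ░ █ █ █ █ █ █

As the second hunk below shows, the library function only adds the sliding-window term when `sliding_window > 1` (1 meaning "no sliding window").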
@@ -63,8 +112,8 @@ def build_attention_mask(
         masked = torch.triu(minus_inf, diagonal=causal_diagonal)
         # Apply sliding window mask if needed
         if sliding_window > 1:
-            sliding_diagonal = seqlen_k - seqlen_q + sliding_window
-            masked = torch.tril(masked, diagonal=sliding_diagonal)
+            sliding_diagonal = seqlen_k - seqlen_q - sliding_window
+            masked += torch.tril(minus_inf, diagonal=sliding_diagonal)
         # Replace in attention mask
         attention_mask[..., query_range, key_range] = masked

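The fix changes both the diagonal and the way the window is applied. In this bias convention (0 = attend, -inf = blocked), `torch.tril(masked, ...)` can only zero out entries, so the old branch could never block anything new, and keys that dropped out of the sliding window stayed visible. A hedged before/after check, using the docstring's second example and the same assumed causal offset as the sketch above:

import torch

# seqlen_k = 5, seqlen_q = 3, sliding_window = 2 (the docstring's second example).
seqlen_q, seqlen_k, sliding_window = 3, 5, 2
minus_inf = torch.full((seqlen_q, seqlen_k), torch.finfo(torch.float32).min)
masked = torch.triu(minus_inf, diagonal=seqlen_k - seqlen_q + 1)  # causal bias: 0 = attend

# Old branch: tril of `masked` only zeroes entries above the cut, never adds -inf.
old = torch.tril(masked, diagonal=seqlen_k - seqlen_q + sliding_window)

# New branch: add a fresh tril of -inf at and below the window boundary.
new = masked + torch.tril(minus_inf, diagonal=seqlen_k - seqlen_q - sliding_window)

print((old == 0).int())  # 11100 / 11110 / 11111 -> identical to the plain causal mask
print((new == 0).int())  # 01100 / 00110 / 00011 -> matches the docstring diagram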
tests/generation/test_continuous_batching.py

Lines changed: 43 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
+from transformers.generation.continuous_batching.continuous_api import build_attention_mask
 from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
 
 
@@ -88,6 +89,48 @@ def test_group_layers(
                 f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
             )
 
+    @parameterized.expand(
+        [
+            ([0, 4], [0, 4], 1, ["1000", "1100", "1110", "1111"]),
+            ([0, 4], [0, 4], 2, ["1000", "1100", "0110", "0011"]),
+            ([0, 3], [0, 5], 1, ["11100", "11110", "11111"]),
+            ([0, 3], [0, 5], 3, ["11100", "01110", "00111"]),
+            ([0, 3, 6], [0, 3, 6], 1, ["100000", "110000", "111000", "000100", "000110", "000111"]),
+            ([0, 3, 6], [0, 3, 6], 2, ["100000", "110000", "011000", "000100", "000110", "000011"]),
+        ]
+    )
+    def test_attention_mask(
+        self,
+        cumulative_seqlens_q: list[int],
+        cumulative_seqlens_k: list[int],
+        sliding_window: int,  # the sliding window size, 1 means no sliding window
+        str_expected_mask: list[str],  # the attention mask, broken down by line as a string of 0s and 1s
+    ) -> None:
+        # Build expected mask
+        minus_inf = torch.finfo(torch.float32).min
+        expected_mask = torch.empty((cumulative_seqlens_q[-1], cumulative_seqlens_k[-1]), dtype=torch.float32)
+        for i, line in enumerate(str_expected_mask):
+            expected_mask[i, :] = torch.tensor([minus_inf if c == "0" else 0 for c in line])
+        # Build actual mask
+        actual_mask = torch.full_like(expected_mask, minus_inf)  # function modifies in place
+        build_attention_mask(
+            actual_mask, torch.tensor(cumulative_seqlens_q), torch.tensor(cumulative_seqlens_k), sliding_window
+        )
+        # Check that the actual mask matches the expected mask
+        matches = (expected_mask == actual_mask).all()
+        # If it doesn't match, print the masks in a readable form and fail the test
+        if not matches:
+            str_mask = [
+                "".join("1" if x == 0 else "0" for x in token_attn_vector) for token_attn_vector in actual_mask
+            ]
+            str_mask = "\n".join(str_mask)
+            str_expected_mask = "\n".join(str_expected_mask)
+            self.fail(
+                f"Test failed for: {cumulative_seqlens_q = }, {cumulative_seqlens_k = }, {sliding_window = }\n"
+                f"Expected mask:\n{str_expected_mask}\n"
+                f"Actual mask:\n{str_mask}"
+            )
+
     def _continuous_batching_parity(
         self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
     ) -> None:
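
Outside the test harness, the same check can be run by hand by calling `build_attention_mask` directly. This sketch mirrors the `([0, 3, 6], [0, 3, 6], sliding_window=2)` case from the table above, using the same import path as the test; with two packed sequences the result is block-diagonal.

import torch

from transformers.generation.continuous_batching.continuous_api import build_attention_mask

# Two packed sequences of length 3 each, sliding window of 2.
cumulative_seqlens_q = torch.tensor([0, 3, 6])
cumulative_seqlens_k = torch.tensor([0, 3, 6])
mask = torch.full((6, 6), torch.finfo(torch.float32).min)  # filled in place by the function

build_attention_mask(mask, cumulative_seqlens_q, cumulative_seqlens_k, 2)

# Print in the same 0/1 string form used by the test ('1' = may attend).
print("\n".join("".join("1" if x == 0 else "0" for x in row) for row in mask))
# 100000
# 110000
# 011000
# 000100
# 000110
# 000011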
