@@ -20,6 +20,7 @@
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
+from transformers.generation.continuous_batching.continuous_api import build_attention_mask
 from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
 
 
@@ -88,6 +89,48 @@ def test_group_layers(
             f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
         )
 
+    @parameterized.expand(
+        [
+            ([0, 4], [0, 4], 1, ["1000", "1100", "1110", "1111"]),
+            ([0, 4], [0, 4], 2, ["1000", "1100", "0110", "0011"]),
+            ([0, 3], [0, 5], 1, ["11100", "11110", "11111"]),
+            ([0, 3], [0, 5], 3, ["11100", "01110", "00111"]),
+            ([0, 3, 6], [0, 3, 6], 1, ["100000", "110000", "111000", "000100", "000110", "000111"]),
+            ([0, 3, 6], [0, 3, 6], 2, ["100000", "110000", "011000", "000100", "000110", "000011"]),
+        ]
+    )
+    def test_attention_mask(
+        self,
+        cumulative_seqlens_q: list[int],
+        cumulative_seqlens_k: list[int],
+        sliding_window: int,  # the sliding window size, 1 means no sliding window
+        str_expected_mask: list[str],  # the attention mask, broken down by line as a string of 0s and 1s
+    ) -> None:
+        # Build expected mask
+        minus_inf = torch.finfo(torch.float32).min
+        expected_mask = torch.empty((cumulative_seqlens_q[-1], cumulative_seqlens_k[-1]), dtype=torch.float32)
+        for i, line in enumerate(str_expected_mask):
+            expected_mask[i, :] = torch.tensor([minus_inf if c == "0" else 0 for c in line])
+        # Build actual mask
+        actual_mask = torch.full_like(expected_mask, minus_inf)  # function modifies in place
+        build_attention_mask(
+            actual_mask, torch.tensor(cumulative_seqlens_q), torch.tensor(cumulative_seqlens_k), sliding_window
+        )
+        # Check that the actual mask matches the expected mask
+        matches = (expected_mask == actual_mask).all()
+        # If it doesn't match, print the masks in a readable form and fail the test
+        if not matches:
+            str_mask = [
+                "".join("1" if x == 0 else "0" for x in token_attn_vector) for token_attn_vector in actual_mask
+            ]
+            str_mask = "\n".join(str_mask)
+            str_expected_mask = "\n".join(str_expected_mask)
+            self.fail(
+                f"Test failed for: {cumulative_seqlens_q = }, {cumulative_seqlens_k = }, {sliding_window = }\n"
+                f"Expected mask:\n{str_expected_mask}\n"
+                f"Actual mask:\n{str_mask}"
+            )
+
     def _continuous_batching_parity(
         self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
     ) -> None:
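
For readers skimming the diff, below is a minimal standalone sketch of what the new test exercises. It assumes only what the test itself shows: build_attention_mask fills a pre-allocated (total_q, total_k) float32 mask in place, writing 0 where a query token may attend and torch.finfo(torch.float32).min elsewhere. The rendering loop at the end mirrors the test's failure output and is purely illustrative.

import torch

from transformers.generation.continuous_batching.continuous_api import build_attention_mask

# One packed query block of 3 tokens over 5 key positions with a sliding window of 3,
# taken from the parameterized cases above (expected rows: 11100 / 01110 / 00111).
cu_q, cu_k, sliding_window = [0, 3], [0, 5], 3
minus_inf = torch.finfo(torch.float32).min

# Pre-allocate a fully masked tensor; build_attention_mask modifies it in place.
mask = torch.full((cu_q[-1], cu_k[-1]), minus_inf, dtype=torch.float32)
build_attention_mask(mask, torch.tensor(cu_q), torch.tensor(cu_k), sliding_window)

# Print each query row as 0s/1s, the same readable form the test uses on failure.
for row in mask:
    print("".join("1" if x == 0 else "0" for x in row))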