|
| 1 | +import logging |
| 2 | +import sys |
| 3 | +import unittest |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +import torch_xla |
| 8 | +import torch_xla.distributed.spmd as xs |
| 9 | +from torch_xla import runtime as xr |
| 10 | +from torch_xla._internal import tpu |
| 11 | +from torch_xla.distributed.spmd import Mesh |
| 12 | +from torch_xla.experimental.custom_kernel import flash_attention |
| 13 | + |
| 14 | +from torch_xla.experimental.splash_attention import ( |
| 15 | + SplashAttentionConfig, |
| 16 | + splash_attention, |
| 17 | +) |
| 18 | + |
| 19 | +import torch_xla.core.xla_builder as xb |
| 20 | + |
| 21 | +if xr.device_type() == "TPU": |
| 22 | + from torch_xla.experimental.custom_kernel import jax_import_guard |
| 23 | + |
| 24 | + jax_import_guard() |
| 25 | + import jax |
| 26 | + |
| 27 | + |
| 28 | +def with_jax_high_precision(func): |
| 29 | + |
| 30 | + def wrapper(*args, **kwargs): |
| 31 | + jax.config.update("jax_default_matmul_precision", "highest") |
| 32 | + try: |
| 33 | + result = func(*args, **kwargs) |
| 34 | + finally: |
| 35 | + jax.config.update("jax_default_matmul_precision", "default") |
| 36 | + return result |
| 37 | + |
| 38 | + return wrapper |
| 39 | + |
| 40 | + |
| 41 | +class SplashAttentionTest(unittest.TestCase): |
| 42 | + |
| 43 | + @with_jax_high_precision |
| 44 | + def setUp(self): |
| 45 | + # Common dimensions for all tests. Spalsh attention kernel requires |
| 46 | + # NUM_HEADS, SEQ_LEN, HEAD_DIM must >= 128. |
| 47 | + self.BATCH_SIZE = 4 |
| 48 | + # Test GQA with different Q and KV heads. |
| 49 | + self.NUM_Q_HEADS = 128 |
| 50 | + self.NUM_KV_HEADS = 64 |
| 51 | + self.NUM_HEADS = 128 |
| 52 | + self.SEQ_LEN = 128 |
| 53 | + self.HEAD_DIM = 128 |
| 54 | + self.partition_spec = (("data", "fsdp"), None, None, None) |
| 55 | + segment_ids_partition_spec = (("data", "fsdp"), None) |
| 56 | + self.config = SplashAttentionConfig( |
| 57 | + mesh=str(xs.get_global_mesh()), |
| 58 | + qkv_partition_spec=self.partition_spec, |
| 59 | + segment_ids_partition_spec=segment_ids_partition_spec, |
| 60 | + ) |
| 61 | + self.q, self.k, self.v, self.q_sa, self.k_sa, self.v_sa = self.ab_comparsion_input_generation( |
| 62 | + ) |
| 63 | + segment_ids = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN).to("xla") |
| 64 | + for i in range(self.BATCH_SIZE): |
| 65 | + segment_ids[i, :] = i |
| 66 | + self.segment_ids_sa = segment_ids.clone().detach() |
| 67 | + self.o = flash_attention( |
| 68 | + self.q, |
| 69 | + self.k, |
| 70 | + self.v, |
| 71 | + True, |
| 72 | + segment_ids, |
| 73 | + segment_ids, |
| 74 | + partition_spec=self.partition_spec, |
| 75 | + mesh=xs.get_global_mesh(), |
| 76 | + ) |
| 77 | + torch_xla.sync() |
| 78 | + loss = torch.sum(self.o) |
| 79 | + loss.backward() |
| 80 | + torch_xla.sync() |
| 81 | + self.q_grad, k_grad, v_grad = self.q.grad, self.k.grad, self.v.grad |
| 82 | + with torch.no_grad(): |
| 83 | + self.k_grad = self.maybe_reduce_kv_grad(k_grad) |
| 84 | + self.v_grad = self.maybe_reduce_kv_grad(v_grad) |
| 85 | + |
| 86 | + def maybe_repeat_kv(self, hidden_state): |
| 87 | + if hidden_state.size(1) == self.NUM_Q_HEADS: |
| 88 | + return hidden_state |
| 89 | + num_kv_group = self.NUM_Q_HEADS // self.NUM_KV_HEADS |
| 90 | + return hidden_state.repeat_interleave(num_kv_group, dim=1) |
| 91 | + |
| 92 | + def maybe_reduce_kv_grad(self, hidden_state_grad): |
| 93 | + # For GQA, the kv grad shape is [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN, |
| 94 | + # HEAD_DIM]. We need to convert it back to [BATCH_SIZE, NUM_KV_HEADS, |
| 95 | + # SEQ_LEN, HEAD_DIM]. The returned grad should be sum over the kv heads over |
| 96 | + # each group to preserve the magnitude of gradients. |
| 97 | + if hidden_state_grad.size(1) == self.NUM_KV_HEADS: |
| 98 | + return hidden_state_grad |
| 99 | + num_kv_group = self.NUM_Q_HEADS // self.NUM_KV_HEADS |
| 100 | + return hidden_state_grad.view( |
| 101 | + self.BATCH_SIZE, |
| 102 | + self.NUM_KV_HEADS, |
| 103 | + num_kv_group, |
| 104 | + self.SEQ_LEN, |
| 105 | + self.HEAD_DIM, |
| 106 | + ).sum(dim=2) |
| 107 | + |
| 108 | + def ab_comparsion_input_generation(self): |
| 109 | + q = torch.randn( |
| 110 | + self.BATCH_SIZE, |
| 111 | + self.NUM_Q_HEADS, |
| 112 | + self.SEQ_LEN, |
| 113 | + self.HEAD_DIM, |
| 114 | + requires_grad=True).to("xla") |
| 115 | + k = torch.randn( |
| 116 | + self.BATCH_SIZE, |
| 117 | + self.NUM_KV_HEADS, |
| 118 | + self.SEQ_LEN, |
| 119 | + self.HEAD_DIM, |
| 120 | + requires_grad=True, |
| 121 | + ).to("xla") |
| 122 | + v = torch.randn( |
| 123 | + self.BATCH_SIZE, |
| 124 | + self.NUM_KV_HEADS, |
| 125 | + self.SEQ_LEN, |
| 126 | + self.HEAD_DIM, |
| 127 | + requires_grad=True, |
| 128 | + ).to("xla") |
| 129 | + q.retain_grad() |
| 130 | + k.retain_grad() |
| 131 | + v.retain_grad() |
| 132 | + q_sa = q.clone().detach().requires_grad_(True) |
| 133 | + k_sa = k.clone().detach().requires_grad_(True) |
| 134 | + v_sa = v.clone().detach().requires_grad_(True) |
| 135 | + # Repeat the kv tensors to match the q tensor heads. This is required for flash |
| 136 | + k = self.maybe_repeat_kv(k) |
| 137 | + k.retain_grad() |
| 138 | + v = self.maybe_repeat_kv(v) |
| 139 | + v.retain_grad() |
| 140 | + torch_xla.sync() |
| 141 | + return q, k, v, q_sa, k_sa, v_sa |
| 142 | + |
| 143 | + def _attention(self, q, k, v, *, attn_mask=None, ab=None): |
| 144 | + k = self.maybe_repeat_kv(k) |
| 145 | + v = self.maybe_repeat_kv(v) |
| 146 | + attn_weight = q @ k.transpose(-2, -1) |
| 147 | + if attn_mask is not None: |
| 148 | + # Masked out the unrelevant parts. |
| 149 | + attn_weight = attn_weight.masked_fill(attn_mask, |
| 150 | + torch.finfo(attn_weight.dtype).min) |
| 151 | + if ab is not None: |
| 152 | + attn_weight = attn_weight + ab |
| 153 | + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) |
| 154 | + attn_output = attn_weight @ v |
| 155 | + return attn_output |
| 156 | + |
| 157 | + @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3, |
| 158 | + "This test only works on TPUv3+.") |
| 159 | + @with_jax_high_precision |
| 160 | + def test_splash_attention_base(self): |
| 161 | + q, k, v, q_sa, k_sa, v_sa = self.ab_comparsion_input_generation() |
| 162 | + attention_mask = torch.triu( |
| 163 | + torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to("xla") |
| 164 | + |
| 165 | + o = self._attention(q, k, v, attn_mask=attention_mask) |
| 166 | + torch_xla.sync() |
| 167 | + loss = torch.sum(o) |
| 168 | + loss.backward() |
| 169 | + q_grad, k_grad, v_grad = q.grad, k.grad, v.grad |
| 170 | + torch_xla.sync() |
| 171 | + |
| 172 | + o_sa = splash_attention(q_sa, k_sa, v_sa, self.config.to_json()) |
| 173 | + torch_xla.sync() |
| 174 | + loss_sa = torch.sum(o_sa) |
| 175 | + loss_sa.backward() |
| 176 | + q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad |
| 177 | + torch_xla.sync() |
| 178 | + |
| 179 | + with torch.no_grad(): |
| 180 | + k_grad = self.maybe_reduce_kv_grad(k_grad) |
| 181 | + v_grad = self.maybe_reduce_kv_grad(v_grad) |
| 182 | + |
| 183 | + torch.testing.assert_close(o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5) |
| 184 | + |
| 185 | + for org_grad, sa_grad in zip([q_grad, k_grad, v_grad], |
| 186 | + [q_grad_sa, k_grad_sa, v_grad_sa], |
| 187 | + strict=False): |
| 188 | + torch.testing.assert_close( |
| 189 | + org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2) |
| 190 | + |
| 191 | + @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3, |
| 192 | + "This test only works on TPUv3+.") |
| 193 | + @with_jax_high_precision |
| 194 | + def test_splash_attention_sharding(self): |
| 195 | + n_devices = xr.global_runtime_device_count() |
| 196 | + q = ( |
| 197 | + torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, |
| 198 | + self.HEAD_DIM).requires_grad_(True).to("xla")) |
| 199 | + k = ( |
| 200 | + torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, |
| 201 | + self.HEAD_DIM).requires_grad_(True).to("xla")) |
| 202 | + v = ( |
| 203 | + torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, |
| 204 | + self.HEAD_DIM).requires_grad_(True).to("xla")) |
| 205 | + o = splash_attention(q, k, v, self.config.to_json()) |
| 206 | + torch_xla.sync() |
| 207 | + self.assertEqual( |
| 208 | + torch_xla._XLAC._get_xla_sharding_spec(o), |
| 209 | + f"{{devices=[{n_devices},1,1,1]<=[{n_devices}]}}", |
| 210 | + ) |
| 211 | + |
| 212 | + @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3, |
| 213 | + "This test only works on TPUv3+.") |
| 214 | + @with_jax_high_precision |
| 215 | + def test_splash_attention_segment_id(self): |
| 216 | + # test the segment id in splash attention against the flash attention kernel |
| 217 | + q_sa = self.q_sa.clone().detach().requires_grad_(True) |
| 218 | + k_sa = self.k_sa.clone().detach().requires_grad_(True) |
| 219 | + v_sa = self.v_sa.clone().detach().requires_grad_(True) |
| 220 | + for i in [q_sa, k_sa, v_sa]: |
| 221 | + i.retain_grad() |
| 222 | + segment_ids_sa = self.segment_ids_sa.clone().detach() |
| 223 | + o_sa = splash_attention( |
| 224 | + q_sa, |
| 225 | + k_sa, |
| 226 | + v_sa, |
| 227 | + self.config.to_json(), |
| 228 | + decoder_segment_ids=segment_ids_sa.to("xla")) |
| 229 | + loss_sa = torch.sum(o_sa) |
| 230 | + loss_sa.backward() |
| 231 | + q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad |
| 232 | + torch_xla.sync() |
| 233 | + torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5) |
| 234 | + for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad], |
| 235 | + [q_grad_sa, k_grad_sa, v_grad_sa], |
| 236 | + strict=False): |
| 237 | + torch.testing.assert_close( |
| 238 | + org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2) |
| 239 | + |
| 240 | + @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3, |
| 241 | + "This test only works on TPUv3+.") |
| 242 | + @with_jax_high_precision |
| 243 | + def test_splash_attention_aot_traceable(self): |
| 244 | + from functorch.compile import aot_function, make_boxed_func |
| 245 | + |
| 246 | + def compiler(gm, _): |
| 247 | + return make_boxed_func(gm) |
| 248 | + |
| 249 | + compiled_splash_attention = aot_function( |
| 250 | + splash_attention, fw_compiler=compiler) |
| 251 | + |
| 252 | + q_sa = self.q_sa.clone().detach().requires_grad_(True) |
| 253 | + k_sa = self.k_sa.clone().detach().requires_grad_(True) |
| 254 | + v_sa = self.v_sa.clone().detach().requires_grad_(True) |
| 255 | + for i in [q_sa, k_sa, v_sa]: |
| 256 | + i.retain_grad() |
| 257 | + segment_ids_sa = self.segment_ids_sa.clone().detach() |
| 258 | + o_sa = compiled_splash_attention( |
| 259 | + q_sa, |
| 260 | + k_sa, |
| 261 | + v_sa, |
| 262 | + self.config.to_json(), |
| 263 | + decoder_segment_ids=segment_ids_sa) |
| 264 | + torch_xla.sync() |
| 265 | + loss_sa = torch.sum(o_sa) |
| 266 | + loss_sa.backward() |
| 267 | + torch_xla.sync() |
| 268 | + q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad |
| 269 | + |
| 270 | + torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5) |
| 271 | + for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad], |
| 272 | + [q_grad_sa, k_grad_sa, v_grad_sa], |
| 273 | + strict=False): |
| 274 | + torch.testing.assert_close( |
| 275 | + org_grad.cpu(), sa_grad.cpu(), rtol=1e-4, atol=1e-2) |
| 276 | + |
| 277 | + @unittest.skipIf(xr.device_type() != "TPU" or tpu.version() < 3, |
| 278 | + "This test only works on TPUv3+.") |
| 279 | + @with_jax_high_precision # remove the decorator will cause failure in other tests :) |
| 280 | + def test_splash_attention_cache_hit(self): |
| 281 | + xb._JAX_TO_XLA_COMPUTATION_CACHE.clear() |
| 282 | + starting_cache_misses = xb._jax_to_xla_computation_cache_num_misses() |
| 283 | + q = self.q_sa.clone().detach().requires_grad_(True) |
| 284 | + k = self.k_sa.clone().detach().requires_grad_(True) |
| 285 | + v = self.v_sa.clone().detach().requires_grad_(True) |
| 286 | + segment_ids = self.segment_ids_sa.clone().detach() |
| 287 | + o = splash_attention( |
| 288 | + q, |
| 289 | + k, |
| 290 | + v, |
| 291 | + self.config.to_json(), |
| 292 | + decoder_segment_ids=segment_ids.to("xla")) |
| 293 | + loss = torch.sum(o) |
| 294 | + loss.backward() |
| 295 | + torch_xla.sync() |
| 296 | + |
| 297 | + q = self.q_sa.clone().detach().requires_grad_(True) |
| 298 | + k = self.k_sa.clone().detach().requires_grad_(True) |
| 299 | + v = self.v_sa.clone().detach().requires_grad_(True) |
| 300 | + q = q * 2 |
| 301 | + segment_ids = self.segment_ids_sa.clone().detach() |
| 302 | + o = splash_attention( |
| 303 | + q, |
| 304 | + k, |
| 305 | + v, |
| 306 | + self.config.to_json(), |
| 307 | + decoder_segment_ids=segment_ids.to("xla")) |
| 308 | + loss = torch.sum(o) |
| 309 | + loss.backward() |
| 310 | + torch_xla.sync() |
| 311 | + ending_cache_misses = xb._jax_to_xla_computation_cache_num_misses() |
| 312 | + # There are 2 misses because we run both forward (+1 miss) and backward (+1 |
| 313 | + # miss) pass. |
| 314 | + self.assertEqual(ending_cache_misses - starting_cache_misses, 2) |
| 315 | + |
| 316 | + |
| 317 | +if __name__ == "__main__": |
| 318 | + logging.getLogger().setLevel(logging.INFO) |
| 319 | + torch.set_default_dtype(torch.float32) |
| 320 | + torch_xla._XLAC._xla_set_mat_mul_precision("highest") |
| 321 | + torch.manual_seed(42) |
| 322 | + xr.use_spmd() |
| 323 | + num_devices = xr.global_runtime_device_count() |
| 324 | + mesh_shape = (num_devices // 2, 2) |
| 325 | + device_ids = np.array(range(num_devices)) |
| 326 | + mesh = Mesh(device_ids, mesh_shape, ("data", "fsdp")) |
| 327 | + xs.set_global_mesh(mesh) |
| 328 | + test = unittest.main() |
| 329 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments