
Commit b6fa01b

enable flash attention
1 parent 8d60bf1 commit b6fa01b

File tree

1 file changed: +12 -8 lines changed

src/transformers/models/llama/modeling_llama.py

Lines changed: 12 additions & 8 deletions
@@ -368,16 +368,20 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
+        # if attention_mask is not None:  # no matter the length, we just slice it
+        #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        #     attn_weights = attn_weights + causal_mask

-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+        # # upcast attention to fp32
+        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        # attn_output = torch.matmul(attn_weights, value_states)
+
+        # Integrated with PyTorch/XLA Pallas Flash Attention:
+        from torch_xla.experimental.custom_kernel import flash_attention
+        attn_output = flash_attention(query_states, key_states, value_states, partition_spec=('fsdp', None, None, None))

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
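The change swaps the eager attention path (commented out above, together with the sliced causal mask, the fp32 softmax upcast, and dropout) for the Pallas flash attention kernel that ships with PyTorch/XLA, sharding the batch dimension over an 'fsdp' SPMD mesh axis via partition_spec. Below is a minimal, hedged sketch of calling that kernel standalone; it assumes a TPU-backed torch_xla install, the tensor sizes are illustrative, and the single-device call without partition_spec is an assumption rather than something taken from this commit.

# Minimal sketch: invoking torch_xla's Pallas flash_attention directly.
# Assumes a TPU-backed torch_xla environment; sizes are illustrative and the
# kernel may impose TPU-specific tiling constraints on seq_len/head_dim.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()

bsz, num_heads, q_len, head_dim = 2, 8, 256, 128
query_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
key_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
value_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)

# Single-device call (assumption); the commit additionally passes
# partition_spec=('fsdp', None, None, None) to shard the batch dimension
# across an SPMD mesh axis named 'fsdp'.
attn_output = flash_attention(query_states, key_states, value_states)
print(attn_output.shape)  # torch.Size([2, 8, 256, 128])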
