Commit 850ac3c

Enable attn_weight scalar
1 parent 05cef15 commit 850ac3c

File tree

1 file changed: +1 −1 lines changed

src/transformers/models/llama/modeling_llama.py

Lines changed: 1 addition & 1 deletion
@@ -381,8 +381,8 @@ def forward(
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             # Integrated with PyTorch/XLA Pallas Flash Attention:
-            # TODO: enable 1 / math.sqrt(self.head_dim).
             from torch_xla.experimental.custom_kernel import flash_attention
+            query_states = query_states / math.sqrt(self.head_dim)
             attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', None, None, None))

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
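For context, a minimal sketch (not part of the commit) of why pre-scaling the queries works: the Pallas flash_attention kernel called above is assumed not to apply the standard 1/sqrt(head_dim) softmax scaling itself, so dividing query_states by math.sqrt(self.head_dim) before the call produces the same attention scores as the eager path, which divides Q·Kᵀ by sqrt(head_dim). The shapes below are hypothetical.

# Minimal sketch: pre-scaling the queries is equivalent to scaling the scores.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8  # hypothetical small shapes
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Eager attention scaling: scores = Q @ K^T / sqrt(d)
eager_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

# Queries divided by sqrt(d), fed to a kernel that applies no scaling: (Q / sqrt(d)) @ K^T
prescaled_scores = torch.matmul(query_states / math.sqrt(head_dim), key_states.transpose(2, 3))

assert torch.allclose(eager_scores, prescaled_scores, atol=1e-6)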
