@@ -368,16 +368,20 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
+        # if attention_mask is not None:  # no matter the length, we just slice it
+        #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        #     attn_weights = attn_weights + causal_mask
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+        # # upcast attention to fp32
+        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        # attn_output = torch.matmul(attn_weights, value_states)
+
+        # Integrated with PyTorch/XLA Pallas Flash Attention:
+        from torch_xla.experimental.custom_kernel import flash_attention
+        attn_output = flash_attention(query_states, key_states, value_states, partition_spec=('fsdp', None, None, None))
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
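
The flash_attention kernel introduced above can also be exercised on its own as a sanity check of the new code path. The snippet below is a minimal, hypothetical sketch, assuming a TPU runtime with torch_xla installed; the tensor names and the [bsz, num_heads, q_len, head_dim] layout mirror the diff, and partition_spec is dropped because it only matters under an SPMD mesh such as the 'fsdp' axis used in the change.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

# Hypothetical standalone check; requires a TPU device.
device = xm.xla_device()
bsz, num_heads, q_len, head_dim = 2, 8, 1024, 128

query_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
key_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
value_states = torch.randn(bsz, num_heads, q_len, head_dim, device=device)

# Same call as in the diff, minus partition_spec (no SPMD mesh in this toy setup).
attn_output = flash_attention(query_states, key_states, value_states)
assert attn_output.size() == (bsz, num_heads, q_len, head_dim)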