@@ -115,11 +115,15 @@ def _get_interleave_power_of_2(n):
115115 )
116116
117117
def get_use_casual_mask():
    """Return whether the 'USE_CASUAL_MASK' environment variable is enabled.

    ``os.getenv`` always yields a string, and every non-empty string --
    including the default ``"False"`` -- is truthy. Returning the raw value
    would therefore make ``if get_use_casual_mask():`` succeed no matter how
    the variable is set, so the value is parsed into a real ``bool`` here.

    Returns:
        bool: True when the variable is set to "true", "1", "yes" or "on"
        (case-insensitive, surrounding whitespace ignored); False when it is
        unset or holds any other value.
    """
    raw_value = os.getenv("USE_CASUAL_MASK", "False")
    return raw_value.strip().lower() in ("true", "1", "yes", "on")
121+
122+
118123def build_alibi_tensor (
119124 bool_attention_mask : Tensor , num_heads : int , dtype : paddle .dtype , tensor_parallel_degree = 1
120125) -> Tensor :
121- attention_mask = bool_attention_mask .astype ("float32" )
122- batch_size , seq_length = attention_mask .shape [0 ], attention_mask .shape [- 1 ]
126+ batch_size , seq_length = bool_attention_mask .shape [0 ], bool_attention_mask .shape [- 1 ]
123127 slopes = paddle .to_tensor (_get_interleave (num_heads ), dtype = "float32" )
124128 alibi = slopes .unsqueeze (axis = [1 , 2 ]) * paddle .arange (seq_length , dtype = "float32" ).unsqueeze (axis = [0 , 1 ]).expand (
125129 [num_heads , - 1 , - 1 ]
@@ -307,7 +311,7 @@ def is_casual_mask(attention_mask):
307311
308312def _make_causal_mask (input_ids_shape , past_key_values_length ):
309313 """
310- Make causal mask used for self-attention
314+ Make causal mask used for self-attention
311315 """
312316 batch_size , target_length = input_ids_shape # target_length: seq_len
313317
@@ -1543,12 +1547,22 @@ def forward(
15431547 if position_ids is None :
15441548 position_ids = paddle .arange (seq_length , dtype = "int64" ).expand ((batch_size , seq_length ))
15451549
1546- attention_mask = self ._prepare_decoder_attention_mask (
1547- attention_mask , (batch_size , seq_length ), cache_length , inputs_embeds .dtype
1548- ) # [bs, 1, seq_len, seq_len]
1550+ use_casual_mask = get_use_casual_mask ()
1551+
1552+ if use_casual_mask :
1553+ attention_mask = None
1554+ else :
1555+ attention_mask = self ._prepare_decoder_attention_mask (
1556+ attention_mask , (batch_size , seq_length ), cache_length , inputs_embeds .dtype
1557+ ) # [bs, 1, seq_len, seq_len]
1558+
15491559 is_casual = False
1560+
15501561 if self .config .use_flash_attention and get_env_device () != "gcu" :
1551- is_casual = is_casual_mask (attention_mask )
1562+ if use_casual_mask :
1563+ is_casual = True
1564+ else :
1565+ is_casual = is_casual_mask (attention_mask )
15521566 if get_env_device () != "npu" :
15531567 if is_casual and alibi is None :
15541568 attention_mask = None
0 commit comments