@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
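
The removed lines cast the boolean mask to float32 before reading its shape; since only the shape is consumed here, the new version reads it from bool_attention_mask directly and drops the cast. For context, a minimal NumPy sketch of the slope/bias construction, assuming the standard ALiBi recipe for a power-of-two head count (an illustration, not the Paddle implementation; get_interleave_power_of_2 mirrors the helper named in the hunk header):

import numpy as np

def get_interleave_power_of_2(n):
    # ALiBi slopes for a power-of-two head count form a geometric
    # sequence: slope_i = 2 ** (-8 * (i + 1) / n).
    start = 2 ** (-(2 ** -(np.log2(n) - 3)))
    return [start * start**i for i in range(n)]

num_heads, batch_size, seq_length = 8, 2, 5
slopes = np.array(get_interleave_power_of_2(num_heads), dtype="float32")
positions = np.arange(seq_length, dtype="float32")

# Each head's slope scales a linear position ramp, giving one bias row
# per head: shape [num_heads, 1, seq_length], then tiled across the batch.
alibi = slopes[:, None, None] * positions[None, None, :]
alibi = np.tile(alibi[None], (batch_size, 1, 1, 1))
print(alibi.shape)  # (2, 8, 1, 5)
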
@@ -307,7 +306,7 @@ def is_casual_mask(attention_mask):

 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
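
The mask built here is the standard lower-triangular causal mask, extended on the left so that all past_key_values_length cached positions stay visible. A minimal sketch of that shape logic, assuming the boolean convention where True means "may attend" (NumPy stand-in; the real function works on paddle tensors):

import numpy as np

def make_causal_mask_sketch(input_ids_shape, past_key_values_length):
    # Row i of the triangular part allows positions <= i; cached (past)
    # positions are always visible to every query position.
    batch_size, target_length = input_ids_shape
    mask = np.tril(np.ones((target_length, target_length), dtype=bool))
    if past_key_values_length > 0:
        past = np.ones((target_length, past_key_values_length), dtype=bool)
        mask = np.concatenate([past, mask], axis=-1)
    # Broadcast to [bs, 1, target_length, target_length + past_length].
    return np.broadcast_to(mask[None, None], (batch_size, 1) + mask.shape)

print(make_causal_mask_sketch((1, 3), 2)[0, 0].astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
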
@@ -1533,12 +1532,23 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        is_casual_mask = (
+            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
+        )
+        if is_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+
         is_casual = False
+
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if is_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
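
Net effect of this hunk: when the config exposes use_casual_mask and it is True, the model skips materializing the dense [bs, 1, seq_len, seq_len] mask, passes attention_mask=None, and the flash-attention branch treats attention as causal without inspecting a mask. A condensed sketch of the resulting control flow, as a standalone hypothetical helper (prepare_mask and is_casual_mask stand in for self._prepare_decoder_attention_mask and the module-level is_casual_mask; the gcu/npu device branches are omitted):

def resolve_mask(config, attention_mask, shape, cache_length, dtype,
                 prepare_mask, is_casual_mask):
    # getattr collapses the hasattr(...) and ... is True check above.
    use_casual = getattr(config, "use_casual_mask", False) is True
    if use_casual:
        # No dense mask is built; causality is enforced inside the
        # flash-attention kernel instead.
        attention_mask = None
        is_casual = config.use_flash_attention
    else:
        attention_mask = prepare_mask(attention_mask, shape, cache_length, dtype)
        # Fall back to inspecting the materialized mask.
        is_casual = config.use_flash_attention and is_casual_mask(attention_mask)
    return attention_mask, is_casual

As in the original, is_casual stays False whenever flash attention is disabled; the causal shortcut only changes which path computes it.
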