
Commit 92b106f (1 parent: 87e4c4f)

Commit message: update

File tree: 3 files changed, +37 -8 lines


llm/run_pretrain.py

Lines changed: 5 additions & 0 deletions
@@ -223,6 +223,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "num_hidden_layers."},
     )
+    use_casual_mask: Optional[bool] = field(
+        default=True,
+        metadata={"help": "whether to use casual mask"},
+    )


 def create_pretrained_dataset(
@@ -476,6 +480,7 @@ def main():
     config.pp_recompute_interval = model_args.pp_recompute_interval
     config.recompute_use_reentrant = model_args.recompute_use_reentrant
     config.use_recompute = training_args.recompute
+    config.use_casual_mask = model_args.use_casual_mask

     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
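
Note: the new ModelArguments field simply flows through to the model config before the LLaMA model is built. Below is a minimal, self-contained sketch of that flow; the dataclass field mirrors the diff, while SimpleNamespace stands in for the real LlamaConfig and is only an illustrative assumption.

# Sketch only: shows how the dataclass flag ends up on the config object.
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Optional

@dataclass
class ModelArguments:
    num_hidden_layers: Optional[int] = field(
        default=None, metadata={"help": "num_hidden_layers."}
    )
    use_casual_mask: Optional[bool] = field(
        default=True, metadata={"help": "whether to use casual mask"}
    )

model_args = ModelArguments()                         # defaults: use_casual_mask=True
config = SimpleNamespace()                            # stand-in for LlamaConfig
config.use_casual_mask = model_args.use_casual_mask   # same assignment as in main()
print(config.use_casual_mask)                         # True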

paddlenlp/trainer/trainer.py

Lines changed: 15 additions & 1 deletion
@@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
                     optimizer._set_broadcast_overlap(True, model)

                 self.optimizer = optimizer
-
         # pure tesnor parallel mode, no pipeline_parallel, no sharding.
         if (
             not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)

+        # stage1 has v1 and v2 version
+        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
+            if "split_param" in self.args.sharding_parallel_config:
+                if (
+                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_all_gather_overlap_forward(True, model)
+            else:
+                if (
+                    hasattr(self.optimizer, "_set_broadcast_overlap")
+                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_broadcast_overlap(True, model)
+
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
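
Note: the added block distinguishes sharding stage1 v1 (parameter broadcast) from stage1 v2 ("split_param", forward all-gather) and enables the matching communication/compute overlap hook only when the optimizer exposes it and the corresponding flag appears in sharding_parallel_config. Below is a standalone sketch of that selection logic; DummyOptimizer and enable_stage1_overlap are illustrative stand-ins, not Paddle API.

# Sketch only: isolates the v1/v2 overlap selection added above.
class DummyOptimizer:
    def _set_all_gather_overlap_forward(self, flag, model):
        print("stage1 v2: forward all-gather overlap =", flag)

    def _set_broadcast_overlap(self, flag, model):
        print("stage1 v1: parameter broadcast overlap =", flag)

def enable_stage1_overlap(optimizer, model, sharding_parallel_config):
    # v2 shards by splitting parameters ("split_param") and overlaps the
    # forward all-gather; v1 overlaps the post-step parameter broadcast.
    if "split_param" in sharding_parallel_config:
        if (
            hasattr(optimizer, "_set_all_gather_overlap_forward")
            and "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    elif (
        hasattr(optimizer, "_set_broadcast_overlap")
        and "enable_stage1_broadcast_overlap" in sharding_parallel_config
    ):
        optimizer._set_broadcast_overlap(True, model)

enable_stage1_overlap(
    DummyOptimizer(),
    model=None,
    sharding_parallel_config="split_param enable_stage1_allgather_overlap",
)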

paddlenlp/transformers/llama/modeling.py

Lines changed: 17 additions & 7 deletions
@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
@@ -307,7 +306,7 @@ def is_casual_mask(attention_mask):

 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len

@@ -1533,12 +1532,23 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        is_casual_mask = (
+            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
+        )
+        if is_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+
         is_casual = False
+
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if is_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
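
Note: with use_casual_mask enabled, the model skips building the dense [bs, 1, seq_len, seq_len] additive mask and lets flash attention apply causal masking internally; otherwise the old mask-preparation path is kept. Below is a condensed sketch of that decision; decide_attention_mask and the SimpleNamespace config are illustrative assumptions, and the real code further refines is_casual via is_casual_mask(attention_mask) on the legacy path.

# Sketch only: condenses the mask decision added to LlamaModel.forward.
from types import SimpleNamespace

def decide_attention_mask(config, attention_mask, prepare_dense_mask):
    """Return (attention_mask, is_casual) following the branching in the diff."""
    use_casual = getattr(config, "use_casual_mask", False) is True
    if use_casual:
        # No dense mask is built; flash attention handles causal masking.
        return None, True
    # Legacy path: build the [bs, 1, seq_len, seq_len] additive mask.
    return prepare_dense_mask(attention_mask), False

config = SimpleNamespace(use_casual_mask=True, use_flash_attention=True)
mask, is_casual = decide_attention_mask(
    config, attention_mask=None, prepare_dense_mask=lambda m: m
)
print(mask, is_casual)  # None True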
