
Commit 7aaa788

Support Sharding Overlap (PaddlePaddle#8473)
* update
* update is_casual_mask to use_casual_mask
* update by environment
1 parent c6f4159 commit 7aaa788

2 files changed (+36 −8 lines)


paddlenlp/trainer/trainer.py

Lines changed: 15 additions & 1 deletion
@@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
                     optimizer._set_broadcast_overlap(True, model)
 
                 self.optimizer = optimizer
-
         # pure tesnor parallel mode, no pipeline_parallel, no sharding.
         if (
             not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
 
+        # stage1 has v1 and v2 version
+        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
+            if "split_param" in self.args.sharding_parallel_config:
+                if (
+                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_all_gather_overlap_forward(True, model)
+            else:
+                if (
+                    hasattr(self.optimizer, "_set_broadcast_overlap")
+                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_broadcast_overlap(True, model)
+
         return model
 
     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
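
The new block selects between two mutually exclusive stage1 overlap modes: with "split_param" in the sharding config (stage1 v2) it overlaps the parameter all-gather with the forward pass; otherwise (v1) it overlaps the parameter broadcast. Both paths are opt-in and guarded by hasattr, so optimizers without these hooks are left untouched. Below is a minimal sketch of how the flags might be supplied from user code, assuming PaddleNLP's usual space-separated sharding_parallel_config convention; the surrounding argument values are illustrative, not taken from this commit.

from paddlenlp.trainer import TrainingArguments

# Stage1 v2 ("split_param"): overlap the parameter all-gather with forward.
args_v2 = TrainingArguments(
    output_dir="./checkpoints",  # illustrative path
    sharding="stage1",           # ShardingOption.SHARD_OP
    sharding_parallel_config="split_param enable_stage1_allgather_overlap",
)

# Stage1 v1 (no "split_param"): overlap the parameter broadcast instead.
args_v1 = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",
    sharding_parallel_config="enable_stage1_broadcast_overlap",
)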

paddlenlp/transformers/llama/modeling.py

Lines changed: 21 additions & 7 deletions
@@ -115,11 +115,15 @@ def _get_interleave_power_of_2(n):
     )
 
 
+def get_use_casual_mask():
+    """Get the value of the 'USE_CASUAL_MASK' environment variable."""
+    return os.getenv("USE_CASUAL_MASK", "False") == "True"
+
+
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
@@ -307,7 +311,7 @@ def is_casual_mask(attention_mask):
 
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
 
@@ -1543,12 +1547,22 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        use_casual_mask = get_use_casual_mask()
+
+        if use_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+
         is_casual = False
+
         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if use_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
         if get_env_device() != "npu":
             if is_casual and alibi is None:
                 attention_mask = None
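
The modeling change makes mask construction opt-out: when the USE_CASUAL_MASK environment variable is set, the dense [bs, 1, seq_len, seq_len] mask is never built and the flash-attention path is told the attention is causal, which saves both the mask memory and the work of constructing it. A self-contained sketch of the same gating pattern follows; toy_forward is a hypothetical stand-in for the model's forward, not PaddleNLP code.

import os

def get_use_casual_mask():
    # Mirrors the helper in the diff: the env var opts in to skipping the mask.
    return os.getenv("USE_CASUAL_MASK", "False") == "True"

def toy_forward(use_flash_attention=True):
    # Hypothetical stand-in for LlamaModel.forward's mask handling.
    use_casual_mask = get_use_casual_mask()
    if use_casual_mask:
        attention_mask = None  # skip building the dense mask entirely
    else:
        attention_mask = "dense mask"  # placeholder for _prepare_decoder_attention_mask(...)
    # With flash attention, a known-causal mask lets the kernel drop the mask.
    is_casual = use_flash_attention and use_casual_mask
    return attention_mask, is_casual

os.environ["USE_CASUAL_MASK"] = "True"
print(toy_forward())  # (None, True)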
