
Commit 31c3b55

[BUG] Fix build train valid test datasets (#8823)
1 parent: 161fb67

1 file changed: 3 additions, 1 deletion

paddlenlp/data/causal_dataset.py

```diff
@@ -147,7 +147,9 @@ def build_train_valid_test_datasets(
     # Parse the values.
     output = get_datasets_weights_and_num_samples(data_prefix, train_val_test_num_samples)
     prefixes, weights, datasets_train_valid_test_num_samples = output
-    train_num_samples, valid_num_samples, test_num_samples = map(sum, zip(*datasets_train_valid_test_num_samples))
+    # NOTE: megatron/gpt_dataset.py has been updated. When creating BlendableDataset, we will use the raw train_val_test_num_samples instead of the expanded ones.
+    # Please refer to https://github.com/NVIDIA/NeMo/blob/72f630d087d45655b1a069dc72debf01dfdbdb2d/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py#L74-L80 for more information
+    train_num_samples, valid_num_samples, test_num_samples = train_val_test_num_samples

     # Build individual datasets.
     train_datasets = []
```
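
Why the raw counts matter: in the Megatron-style blending path, get_datasets_weights_and_num_samples pads each dataset's share of the requested samples so that random blending never exhausts any single source. Summing those padded per-dataset counts (the removed map(sum, zip(*...)) line) therefore yields slightly more samples than the caller asked for, while the fix sizes BlendableDataset to the raw totals, matching the referenced NeMo code. The sketch below is a minimal illustration, not the PaddleNLP source; the ceil(requested * weight * 1.005) padding rule is an assumption modeled on the Megatron-style helper.

```python
# Minimal sketch (assumption: the padding rule below mirrors the
# Megatron-style get_datasets_weights_and_num_samples helper).
import math


def expand_num_samples(weights, train_val_test_num_samples):
    # Per-dataset [train, valid, test] counts, padded by ~0.5% so random
    # blending never runs out of samples from any single dataset.
    return [
        [int(math.ceil(n * w * 1.005)) for n in train_val_test_num_samples]
        for w in weights
    ]


requested = [1000, 100, 10]  # raw train/valid/test sample counts
weights = [0.3, 0.7]         # normalized blending weights

expanded = expand_num_samples(weights, requested)
summed = list(map(sum, zip(*expanded)))  # old behaviour: sum the padded counts

print(expanded)   # [[302, 31, 4], [704, 71, 8]]
print(summed)     # [1006, 102, 12] -> overshoots the requested totals
print(requested)  # [1000, 100, 10] -> what BlendableDataset is now sized to
```

As in the referenced NeMo code, the per-prefix datasets built below this hunk still receive the expanded counts; only the size of the blended dataset changes to the requested totals.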
