
Commit 14764e9

Author: yunfan
Commit message: update
Parent: 4bb5dc6

File tree: 7 files changed (+52 / -46 lines)


pretrain/README.md

Lines changed: 12 additions & 3 deletions
@@ -2,13 +2,22 @@
 
 The code for pre-training CPT is based on [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
 
-For the **Setup**, **Data Processing** and **Training** of CPT, you can refer to the [README](README_megatron.md) of Megatron-LM. The package [jieba_fast](https://github.com/deepcs233/jieba_fast) is needed for Whole Word Masking pre-training.
+For the **Setup** and **Data Processing** of CPT, you can refer to the [README](README_megatron.md) of Megatron-LM. The package [jieba_fast](https://github.com/deepcs233/jieba_fast) is needed for Whole Word Masking pre-training.
 
-After processing the data, place the `.bin` and `.idx` files into `./dataset/` and the vocab files into `vocab/bert_zh_vocab/`. Then use the scripts `run_pretrain_bart.sh` and `run_pretrain_cpt.sh` to train Chinese BART and CPT, respectively.
+## Training
+First, prepare the files in the following folders:
+- `dataset/`: the `.bin` and `.idx` files preprocessed from raw text.
+- `vocab/`: the vocab files and the model config file.
+- `roberta_zh/`: the checkpoint of Chinese RoBERTa, since CPT initializes its encoder from this checkpoint.
+
+Then use the scripts `run_pretrain_bart.sh` and `run_pretrain_cpt.sh` to train Chinese BART and CPT, respectively.
+
+*NOTE: the training scripts are distributed examples for 8 GPUs. You may alter the number of GPUs and change the training steps to meet your needs.*
 
 ## Main Changes
 - Add `bart_model` and `cpt_model` under `megatron/model`, so that Megatron can train BART and CPT.
-- Add `_HfAutoTokenizer` in `megatron/tokenizer/tokenizer.py` so that Megatron can use tokenizers from Huggingface-Transformers.
+- Add `_HfBertTokenizer` in `megatron/tokenizer/tokenizer.py` so that Megatron can use tokenizers from Huggingface-Transformers.
 - Add `bart_dataset` and `cpt_dataset` under `megatron/data` to produce data for Whole Word Masking (WWM) and Denoising Auto-Encoder (DAE) pre-training.
 - Add `tools/convert_ckpt.py` to convert Megatron checkpoints to Huggingface-Transformers format.
 - Add `tools/preprocess_data.py` to preprocess and chunk large amounts of text data into the binary format used by Megatron.
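
A quick way to catch a missing piece of the folder layout described in this README before launching a run is a small pre-flight check. The sketch below assumes the folder names from the README; the exact file names inside `vocab/` (`vocab.txt`, `config.json`) and the preprocessing output names are assumptions and may differ in your setup.

```python
# Minimal pre-flight check for the folder layout described in the README above.
from pathlib import Path

def check_layout(root="."):
    """Return a list of problems with the expected pre-training layout."""
    root = Path(root)
    problems = []

    # .bin/.idx pairs produced by tools/preprocess_data.py
    dataset = root / "dataset"
    if not (list(dataset.glob("*.bin")) and list(dataset.glob("*.idx"))):
        problems.append("dataset/ should contain the .bin/.idx files from preprocessing")

    # vocab.txt / config.json are assumed names for the BERT-style vocab and
    # the model config that the pre-training code reads from this folder.
    for name in ("vocab.txt", "config.json"):
        if not (root / "vocab" / name).exists():
            problems.append(f"vocab/{name} is missing")

    # Chinese RoBERTa checkpoint used to initialize the CPT encoder.
    if not (root / "roberta_zh").is_dir():
        problems.append("roberta_zh/ checkpoint folder is missing")

    return problems

if __name__ == "__main__":
    for problem in check_layout():
        print("WARNING:", problem)
```

Run it from the `pretrain/` directory before invoking the training scripts.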

pretrain/megatron/model/bart_model.py

Lines changed: 1 addition & 2 deletions
@@ -25,8 +25,7 @@ def __init__(self):
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
-        # config = BartConfig.from_pretrained(args.vocab_file)  # vocab file path also contains config.json
-        config = BartConfig.from_pretrained('vocab/bart_zh_vocab')
+        config = BartConfig.from_pretrained(args.vocab_file)  # vocab file path also contains config.json
         # encoder_config = BertConfig.from_pretrained(model_path)
         tokenizer = get_tokenizer()
         config.vocab_size = tokenizer.vocab_size
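
For orientation, the pattern this hunk settles on (read the BART config from the same directory as the vocab, then override the vocabulary size from the tokenizer) looks roughly like the standalone sketch below. This is not the Megatron code itself; the `vocab/` path stands in for the `--vocab-file` value from the training scripts and is an assumption.

```python
# Standalone sketch of the config-loading pattern above: the vocab directory
# also holds config.json, so BartConfig can be read from it directly, and
# vocab_size is then overridden from the actual tokenizer.
from transformers import BartConfig, BertTokenizer

vocab_dir = "vocab/"  # corresponds to --vocab-file in the run_pretrain_*.sh scripts
config = BartConfig.from_pretrained(vocab_dir)        # reads vocab/config.json
tokenizer = BertTokenizer.from_pretrained(vocab_dir)  # reads vocab/vocab.txt
config.vocab_size = tokenizer.vocab_size              # keep config and tokenizer in sync
```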

pretrain/megatron/model/cpt_model.py

Lines changed: 4 additions & 9 deletions
@@ -14,24 +14,19 @@
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
-# from transformers import BartForConditionalGeneration as HFBartModel
-# from megatron.model.modeling_bart import BartForConditionalGeneration as HFBartModel
-from megatron.model.modeling_cpt import BartForConditionalGeneration as HFBartModel
+from megatron.model.modeling_cpt import CPTForConditionalGeneration as HFBartModel
 from transformers import BertConfig, BartConfig
 
-class BartModel(MegatronModule):
+class CPTModel(MegatronModule):
     def __init__(self):
         super().__init__()
         args = get_args()
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
         # config = BartConfig.from_pretrained(args.vocab_file)  # vocab file path also contains config.json
-        if args.num_layers > 12:
-            model_path = 'roberta-zh/large'
-        else:
-            model_path = 'roberta-zh/base'
-        config = BartConfig.from_pretrained('vocab/bart_zh_vocab')
+        model_path = 'roberta_zh'
+        config = BartConfig.from_pretrained(args.vocab_file)  # vocab file path also contains config.json
         # encoder_config = BertConfig.from_pretrained(model_path)
         tokenizer = get_tokenizer()
         config.vocab_size = tokenizer.vocab_size
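
The new `model_path = 'roberta_zh'` points at the checkpoint folder the README asks you to prepare. How `CPTModel` actually consumes it is not visible in this hunk (the `BertConfig.from_pretrained(model_path)` line stays commented out), but since common Chinese RoBERTa checkpoints use the BERT architecture, loading one with Huggingface would look roughly like this hedged sketch:

```python
# Hedged sketch: read the roberta_zh checkpoint with the BERT classes from
# transformers, which is how BERT-architecture Chinese RoBERTa weights are
# normally loaded. Whether CPTModel copies these weights into its encoder
# tensor-by-tensor or via from_pretrained() is not shown in this diff.
from transformers import BertConfig, BertModel

model_path = "roberta_zh"  # same folder name as in cpt_model.py above
encoder_config = BertConfig.from_pretrained(model_path)
encoder = BertModel.from_pretrained(model_path, config=encoder_config)
```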

pretrain/megatron/tokenizer/tokenizer.py

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,7 @@
 
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
-from transformers import AutoTokenizer
+from transformers import BertTokenizer
 
 def build_tokenizer(args):
     """Initialize tokenizer."""
@@ -42,7 +42,7 @@ def build_tokenizer(args):
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
     elif args.tokenizer_type == 'Huggingface':
-        tokenizer = _HfAutoTokenizer(args.vocab_file)
+        tokenizer = _HfBertTokenizer(args.vocab_file)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
@@ -128,9 +128,9 @@ def mask(self):
                                   'tokenizer'.format(self.name))
 
 
-class _HfAutoTokenizer(AbstractTokenizer):
+class _HfBertTokenizer(AbstractTokenizer):
     def __init__(self, from_pretrained_path):
-        self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained_path)
+        self.tokenizer = BertTokenizer.from_pretrained(from_pretrained_path)
         self.tokenizer_type = self.tokenizer.__class__.__name__
         self._inv_vocab = {i:t for t,i in self.tokenizer.get_vocab().items()}
         super().__init__('Huggingface Tokenizer {}'.format(self.tokenizer_type))
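
Pieced together, the wrapper this file adds adapts a Huggingface `BertTokenizer` to Megatron's `AbstractTokenizer` interface. Below is a condensed sketch of that adapter, limited to what is visible in this diff; the remaining abstract methods such as `tokenize`/`detokenize` and the special-token ids are elided.

```python
from transformers import BertTokenizer

class _HfBertTokenizerSketch:
    """Condensed view of the _HfBertTokenizer wrapper above; not the full class."""

    def __init__(self, from_pretrained_path):
        self.tokenizer = BertTokenizer.from_pretrained(from_pretrained_path)
        self.tokenizer_type = self.tokenizer.__class__.__name__
        # Megatron wants an id -> token map; Huggingface's get_vocab() is token -> id.
        self._inv_vocab = {i: t for t, i in self.tokenizer.get_vocab().items()}

    @property
    def vocab(self):
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size
```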

pretrain/pretrain_cpt.py

Lines changed: 17 additions & 14 deletions
@@ -27,7 +27,7 @@
     get_tokenizer
 )
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model.cpt_model import BartModel
+from megatron.model.cpt_model import CPTModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
@@ -37,15 +37,16 @@ def model_provider(pre_process=True, post_process=True):
     assert pre_process and post_process, "BART doesn't yet support pipelining"
 
     print_rank_0('building BART model ...')
-    model = BartModel()
+    model = CPTModel()
     print_rank_0(model)
     return model
 
 
+SHOW_DATA = False
 def get_batch(data_iterator):
     """Build the batch."""
 
-    keys = ['source', 'target', 'prev_output_tokens', 'pos1', 'pos2', 'attn_mask', 'loss_mask']
+    keys = ['source', 'target', 'prev_output_tokens', 'attn_mask', 'loss_mask', 'use_decoder']
     datatype = torch.int64
 
     # Broadcast data.
@@ -59,17 +60,19 @@ def get_batch(data_iterator):
     source = data_b['source'].long()
     target = data_b['target'].long()
     prev_output_tokens = data_b['prev_output_tokens'].long()
-    pos1 = data_b['pos1'].long()
-    pos2 = data_b['pos2'].long()
     attn_mask = data_b['attn_mask'].long()
     loss_mask = data_b['loss_mask'].float()
-    # print('source', source[0])
-    # print('target', target[0])
-    # tokenizer = get_tokenizer()
-    # print('source', tokenizer.detokenize(source[0]))
-    # print('target', tokenizer.detokenize(target[0]))
-    # import pdb; pdb.set_trace()
-    return source, target, prev_output_tokens, pos1, pos2, attn_mask, loss_mask
+    use_decoder = data_b['use_decoder'].long()
+
+    global SHOW_DATA
+    if not SHOW_DATA:
+        SHOW_DATA = True
+        print_rank_0('source: {}'.format(source[0]))
+        print_rank_0('target: {}'.format(target[0]))
+        tokenizer = get_tokenizer()
+        print_rank_0('source: {}'.format(tokenizer.detokenize(source[0])))
+        print_rank_0('target: {}'.format(tokenizer.detokenize(target[0])))
+    return source, target, prev_output_tokens, attn_mask, loss_mask, use_decoder
 
 
 def loss_func(loss_mask, output_tensor):
@@ -94,11 +97,11 @@ def forward_step(data_iterator, model):
 
     # Get the batch.
     timers('batch-generator').start()
-    source, target, prev_output_tokens, pos1, pos2, attn_mask, loss_mask = get_batch(data_iterator)
+    source, target, prev_output_tokens, attn_mask, loss_mask, use_decoder = get_batch(data_iterator)
     timers('batch-generator').stop()
 
     # Forward model lm_labels
-    output_tensor = model(source, attn_mask, prev_output_tokens, pos1, pos2, target)
+    output_tensor = model(source, attn_mask, prev_output_tokens, target, use_decoder)
 
     return output_tensor, partial(loss_func, loss_mask)
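
Two things change in `get_batch()`: the broadcast keys drop `pos1`/`pos2` in favor of `use_decoder`, and the old commented-out debug prints become a one-shot inspection of the first batch. In isolation, that one-shot pattern looks like the sketch below, with `print_rank_0`, `get_tokenizer`, and the `data_b` dict passed in to stand for the Megatron objects used above.

```python
SHOW_DATA = False  # module-level flag so the batch is printed exactly once

def show_first_batch(data_b, print_rank_0, get_tokenizer):
    """Print the first batch as token ids and as detokenized text, then go quiet."""
    global SHOW_DATA
    if SHOW_DATA:
        return
    SHOW_DATA = True
    source, target = data_b["source"], data_b["target"]
    print_rank_0("source ids:  {}".format(source[0]))
    print_rank_0("target ids:  {}".format(target[0]))
    tokenizer = get_tokenizer()
    print_rank_0("source text: {}".format(tokenizer.detokenize(source[0])))
    print_rank_0("target text: {}".format(tokenizer.detokenize(target[0])))
```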

pretrain/run_pretrain_bart.sh

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
 
 DATA_PATH="dataset/"
 CHECKPOINT_PATH=checkpoints/bart-base
-VOCAB_FILE=vocab/bert_zh_vocab/
+VOCAB_FILE=vocab/
 
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 
@@ -39,11 +39,11 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --initial-loss-scale 65536 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --lr-warmup-fraction .032 \
+       --lr-warmup-fraction .01 \
        --log-interval 1 \
        --save-interval 1600 \
        --eval-interval 500 \
-       --eval-iters 1 \
+       --eval-iters 10 \
        --fp16 \
        --optimizer adam \
        --num-workers 2 \

pretrain/run_pretrain_cpt.sh

Lines changed: 11 additions & 11 deletions
@@ -10,7 +10,7 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
 
 DATA_PATH="dataset/"
 CHECKPOINT_PATH=checkpoints/cpt-base
-VOCAB_FILE=vocab/bert_zh_vocab/
+VOCAB_FILE=vocab/
 
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 
@@ -20,14 +20,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --num-decoder-layers 2 \
        --hidden-size 768 \
        --num-attention-heads 12 \
-       --micro-batch-size 32 \
-       --global-batch-size 512 \
+       --micro-batch-size 16 \
+       --global-batch-size 256 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --mask-prob 0.15 \
-       --train-iters 1000000 \
-       --lr-decay-iters 1000000 \
-       --lr-warmup-fraction .01 \
+       --train-iters 100000 \
+       --lr-decay-iters 100000 \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
        --data-path $DATA_PATH \
@@ -36,16 +35,17 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --split 949,30,1 \
        --distributed-backend nccl \
        --lr 1e-4 \
-       --lr-encoder 5e-5 \
        --lr-decay-style cosine \
        --min-lr 1e-6 \
+       --initial-loss-scale 65536 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --initial-loss-scale 65536 \
-       --log-interval 10 \
-       --save-interval 10000 \
+       --lr-warmup-fraction .01 \
+       --log-interval 1 \
+       --save-interval 1600 \
        --eval-interval 500 \
        --eval-iters 10 \
-       --num-workers 2 \
        --fp16 \
+       --optimizer adam \
+       --num-workers 2 \
        # --checkpoint-activations
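
The batch-size and schedule changes above work out as follows, assuming the single-node 8-GPU launch from the README (data-parallel size 8, no model parallelism); that launch setup is an assumption rather than something fixed by the script itself.

```python
# Worked numbers for the updated run_pretrain_cpt.sh flags; the data-parallel
# size of 8 is taken from the README's 8-GPU example and may differ for you.
micro_batch_size   = 16
global_batch_size  = 256
data_parallel_size = 8          # GPUS_PER_NODE * NNODES in the script

grad_accum_steps = global_batch_size // (micro_batch_size * data_parallel_size)
print(grad_accum_steps)         # -> 2 accumulation steps per optimizer update

train_iters        = 100_000    # also used as --lr-decay-iters
lr_warmup_fraction = 0.01       # fraction of the decay horizon spent warming up
print(int(train_iters * lr_warmup_fraction))  # -> 1000 warmup iterations
```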
