
Commit 50b2361

refine comments.
1 parent 78e2952 commit 50b2361

4 files changed: 350 additions & 184 deletions

fluid/NMT_Transformer/README.md

Lines changed: 17 additions & 3 deletions
@@ -1,5 +1,19 @@
-# Transformer
+# Attention is All You Need: A Paddle Fluid implementation
 
-Set the model and training configurations in `config.py`, and execute `python train.py` to train.
+This is a Paddle Fluid implementation of the Transformer model in [Attention is All You Need](https://arxiv.org/abs/1706.03762) (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
 
-More details to be added.
+If you use the dataset/code in your research, please cite the paper:
+
+```text
+@inproceedings{vaswani2017attention,
+  title={Attention is all you need},
+  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
+  booktitle={Advances in Neural Information Processing Systems},
+  pages={6000--6010},
+  year={2017}
+}
+```
+
+### TODO
+
+This project is still under active development.

fluid/NMT_Transformer/config.py

Lines changed: 72 additions & 46 deletions
@@ -1,47 +1,73 @@
-# Represent the dict sizes of source and target language. The dict from the
-# dataset here used includes the <bos>, <eos> and <unk> token but exlcudes
-# the <pad> token. It should plus 1 to include the padding token when used as
-# the size of lookup table.
-src_vocab_size = 10000
-trg_vocab_size = 10000
-# Represent the id of <pad> token in source language.
-src_pad_idx = src_vocab_size
-# Represent the id of <pad> token in target language.
-trg_pad_idx = trg_vocab_size
-# Represent the position value corresponding to the <pad> token.
-pos_pad_idx = 0
-# Represent the max length of sequences. It should plus 1 to include position
-# padding token for position encoding.
-max_length = 50
-# Represent the epoch number to train.
-pass_num = 2
-# Represent the number of sequences contained in a mini-batch.
-batch_size = 64
-# Reprent the params for Adam optimizer.
-learning_rate = 0.001
-beta1 = 0.9
-beta2 = 0.98
-eps = 1e-9
-# Represent the dimension of embeddings, which is also the last dimension of
-# the input and output of multi-head attention, position-wise feed-forward
-# networks, encoder and decoder.
-d_model = 512
-# Represent the size of the hidden layer in position-wise feed-forward networks.
-d_inner_hid = 1024
-# Represent the dimension keys are projected to for dot-product attention.
-d_key = 64
-# Represent the dimension values are projected to for dot-product attention.
-d_value = 64
-# Represent the number of head used in multi-head attention.
-n_head = 8
-# Represent the number of sub-layers to be stacked in the encoder and decoder.
-n_layer = 6
-# Represent the dropout rate used by all dropout layers.
-dropout = 0.1
-
-# Names of position encoding table which will be initialized in external.
-pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")
+class TrainTaskConfig(object):
+    use_gpu = False
+    # the number of epochs to train.
+    pass_num = 2
+
+    # number of sequences contained in a mini-batch.
+    batch_size = 64
+
+    # the hyper-parameters for the Adam optimizer.
+    learning_rate = 0.001
+    beta1 = 0.9
+    beta2 = 0.98
+    eps = 1e-9
+
+
+class ModelHyperParams(object):
+    # Dictionary size for source and target language. This model directly uses
+    # paddle.dataset.wmt16, in which the <bos>, <eos> and <unk> tokens have
+    # already been added, but the <pad> token has not. Transformer requires
+    # sequences in a mini-batch to be padded to the same length, so a <pad>
+    # token is added to the original dictionary of paddle.dataset.wmt16.
+
+    # size of the source word dictionary.
+    src_vocab_size = 10000
+    # index for the <pad> token in the source language.
+    src_pad_idx = src_vocab_size
+
+    # size of the target word dictionary.
+    trg_vocab_size = 10000
+    # index for the <pad> token in the target language.
+    trg_pad_idx = trg_vocab_size
+
+    # position value corresponding to the <pad> token.
+    pos_pad_idx = 0
+
+    # max length of sequences. It should be increased by 1 to include the
+    # position padding token used in position encoding.
+    max_length = 50
+
+    # the dimension of word embeddings, which is also the last dimension of
+    # the input and output of multi-head attention, the position-wise
+    # feed-forward networks, and the encoder and decoder.
+
+    d_model = 512
+    # size of the hidden layer in the position-wise feed-forward networks.
+    d_inner_hid = 1024
+    # the dimension that keys are projected to for dot-product attention.
+    d_key = 64
+    # the dimension that values are projected to for dot-product attention.
+    d_value = 64
+    # number of heads used in multi-head attention.
+    n_head = 8
+    # number of sub-layers to be stacked in the encoder and decoder.
+    n_layer = 6
+    # dropout rate used by all dropout layers.
+    dropout = 0.1
+
+
+# Names of the position encoding tables, which will be initialized externally.
+pos_enc_param_names = (
+    "src_pos_enc_table",
+    "trg_pos_enc_table", )
+
 # Names of all data layers listed in order.
-input_data_names = ("src_word", "src_pos", "trg_word", "trg_pos",
-                    "src_slf_attn_bias", "trg_slf_attn_bias",
-                    "trg_src_attn_bias", "lbl_word")
+input_data_names = (
+    "src_word",
+    "src_pos",
+    "trg_word",
+    "trg_pos",
+    "src_slf_attn_bias",
+    "trg_slf_attn_bias",
+    "trg_src_attn_bias",
+    "lbl_word", )

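The comments in `config.py` note that the `<pad>` token is appended to the `paddle.dataset.wmt16` dictionary and that the lookup tables therefore need one extra row. A minimal sketch of that arithmetic (my own illustration, not part of this commit; it assumes `config.py` is on the import path):

```python
# Illustration only: how the padded sizes in ModelHyperParams are meant to be
# used, per the comments in this commit.
from config import ModelHyperParams as hp  # assumes fluid/NMT_Transformer is the cwd

# The <pad> token is appended after the original 10000 entries, so its index
# equals the vocabulary size and the embedding tables need vocab_size + 1 rows.
assert hp.src_pad_idx == hp.src_vocab_size
src_emb_rows = hp.src_vocab_size + 1  # 10001
trg_emb_rows = hp.trg_vocab_size + 1  # 10001

# Likewise, the position encoding tables hold max_length + 1 rows, reserving
# pos_pad_idx (0) for padded positions.
pos_enc_rows = hp.max_length + 1  # 51
print(src_emb_rows, trg_emb_rows, pos_enc_rows)
```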
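The comment above `pos_enc_param_names` says the position encoding tables are initialized externally, but that initialization is not part of this diff. Below is a sketch of the sinusoidal scheme from the paper that such tables commonly use; the function name and shape convention are assumptions, not the repository's actual code:

```python
import numpy as np


def sinusoid_position_encoding(n_position, d_model):
    """Return an (n_position, d_model) sinusoidal table that could be used to
    fill the parameters named in pos_enc_param_names (an assumed helper)."""
    # angle(pos, i) = pos / 10000^(2 * (i // 2) / d_model)
    angles = np.array([[
        pos / np.power(10000.0, 2 * (i // 2) / d_model) for i in range(d_model)
    ] for pos in range(n_position)])
    table = np.zeros_like(angles)
    table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions use sine
    table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions use cosine
    return table.astype("float32")


# One table per name in pos_enc_param_names, sized (max_length + 1, d_model) so
# that row 0 can act as the position padding entry (pos_pad_idx = 0).
src_table = sinusoid_position_encoding(50 + 1, 512)
```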