24 commits
ba088c1  init nmt (Superjomn, Nov 22, 2017)
c5da31a  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 4, 2017)
f5e4cb6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 4, 2017)
19dee22  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 4, 2017)
a95984b  encoder ready (Superjomn, Dec 4, 2017)
ec24923  only generation implementation (Superjomn, Dec 5, 2017)
fffceb8  Merge branch 'feature/nmt-on-while' into feature/nmt-model (Superjomn, Dec 5, 2017)
3be6422  init python (Superjomn, Dec 6, 2017)
cf56456  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 6, 2017)
d9a321c  remove decoder temporary (Superjomn, Dec 6, 2017)
4a0567f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 7, 2017)
a5b9399  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 7, 2017)
d068644  add implementation of decoder (Superjomn, Dec 11, 2017)
c5275a8  merged (Superjomn, Dec 11, 2017)
d5784f3  add (Superjomn, Dec 11, 2017)
a823d46  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 11, 2017)
20a105e  clean code (Superjomn, Dec 11, 2017)
71f5c72  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 18, 2017)
9128bb6  restore op_registry.h (Superjomn, Dec 18, 2017)
8648790  restore op_registry.h (Superjomn, Dec 18, 2017)
70a4af4  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 19, 2017)
98dec8f  fix fluid python bugs (Superjomn, Dec 20, 2017)
eb252c4  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Dec 21, 2017)
ec42a61  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Superjomn, Jan 2, 2018)
36 changes: 34 additions & 2 deletions paddle/operators/beam_search_op.cc
@@ -177,8 +177,40 @@ class BeamSearchProtoAndCheckerMaker
}
};

class BeamSearchInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
for (const std::string &arg :
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
PADDLE_ENFORCE(context->HasInput(arg),
"BeamSearch need input argument '%s'", arg);
}
for (const std::string &arg :
std::vector<std::string>({"selected_ids", "selected_scores"})) {
PADDLE_ENFORCE(context->HasOutput(arg),
"BeamSearch need output argument '%s'", arg);
}
}
};

class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
}
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(beam_search, paddle::operators::BeamSearchOp,
paddle::operators::BeamSearchProtoAndCheckerMaker);
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
paddle::operators::BeamSearchProtoAndCheckerMaker,
paddle::operators::BeamSearchInferShape,
paddle::operators::BeamSearchInferVarType,
paddle::framework::EmptyGradOpMaker);
1 change: 1 addition & 0 deletions paddle/operators/sequence_expand_op.h
@@ -32,6 +32,7 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>();
auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
y->lod().back().size() - 1,
"The size of last lod level in Input(Y)"
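The new check rejects a Y input that carries no LoD. As a minimal sketch only (placeholder names x and y, fluid layer API as used elsewhere in this PR), a call that satisfies the check declares Y with a LoD level:

```python
import paddle.v2.fluid.layers as pd

# Y must carry LoD, otherwise the new PADDLE_ENFORCE
# ("y should have lod") fires when the kernel runs.
x = pd.data(name='x', shape=[1], dtype='float32')
y = pd.data(name='y', shape=[1], dtype='float32', lod_level=1)
out = pd.sequence_expand(x, y)  # rows of x repeated per Y's last LoD level
```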
3 changes: 2 additions & 1 deletion python/paddle/v2/fluid/layer_helper.py
@@ -86,7 +86,8 @@ def input_dtype(self, input_param_name='input'):
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError("Data Type mismatch")
raise ValueError("Data Type mismatch: %d to %d" %
(dtype, each.dtype))
return dtype

def create_parameter(self,
2 changes: 1 addition & 1 deletion python/paddle/v2/fluid/layers/control_flow.py
@@ -429,7 +429,7 @@ def max_sequence_len(rank_table):

def topk(input, k):
helper = LayerHelper('topk', **locals())
topk_out = helper.create_tmp_variable(dtype=input.data_type)
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype='int64')
helper.append_op(
type='top_k',
36 changes: 34 additions & 2 deletions python/paddle/v2/fluid/layers/nn.py
@@ -13,8 +13,8 @@
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
'sequence_first_step', 'sequence_last_step'
'beam_search', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max',
'reduce_min', 'sequence_first_step', 'sequence_last_step'
]


@@ -1130,6 +1130,38 @@ def sequence_expand(x, y):
return tmp


def beam_search(pre_ids, ids, scores, beam_size, end_id):
'''
This function implements the beam search algorithm over the
candidates of one decoding step: it keeps at most `beam_size`
hypotheses per source sequence and returns their ids and scores
as two LoD tensors. `end_id` marks a finished hypothesis.
'''
helper = LayerHelper('beam_search', **locals())
score_type = scores.dtype
id_type = ids.dtype

selected_scores = helper.create_tmp_variable(dtype=score_type)
selected_ids = helper.create_tmp_variable(dtype=id_type)

helper.append_op(
type='beam_search',
inputs={
'pre_ids': pre_ids,
'ids': ids,
'scores': scores,
},
outputs={
'selected_ids': selected_ids,
'selected_scores': selected_scores,
},
attrs={
# TODO(ChunweiYan) add support for level values other than 0
'level': 0,
'beam_size': beam_size,
'end_id': end_id,
})

return selected_ids, selected_scores


def lstm_unit(x_t,
hidden_t_prev,
cell_t_prev,
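A usage sketch for the beam_search layer added in this hunk, mirroring the (currently commented-out) decode step in the translation test further down; pre_ids and scores are assumed to be LoD tensors produced inside the decoding While loop:

```python
import paddle.v2.fluid.layers as pd

def beam_step(pre_ids, scores, beam_size, end_id):
    # Keep only the top-k words of this step's softmax output,
    # then prune down to `beam_size` hypotheses per source sentence.
    topk_scores, topk_indices = pd.topk(scores, k=50)
    return pd.beam_search(pre_ids, topk_indices, topk_scores,
                          beam_size, end_id=end_id)
```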
179 changes: 154 additions & 25 deletions python/paddle/v2/fluid/tests/book/test_machine_translation.py
@@ -3,7 +3,7 @@
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.layers as pd
from paddle.v2.fluid.executor import Executor

dict_size = 30000
@@ -13,52 +13,132 @@
word_dim = 16
IS_SPARSE = True
batch_size = 10
max_length = 50
max_length = 8
topk_size = 50
trg_dic_size = 10000
beam_size = 2

decoder_size = hidden_dim

place = core.CPUPlace()

def encoder_decoder():

def encoder():
# encoder
src_word_id = layers.data(
src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = layers.embedding(
src_embedding = pd.embedding(
input=src_word_id,
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(name='vemb'))

fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
encoder_out = layers.sequence_last_step(input=lstm_hidden0)
fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
encoder_out = pd.sequence_last_step(input=lstm_hidden0)
return encoder_out


def decoder_train(context):
# decoder
trg_language_word = layers.data(
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = layers.embedding(
trg_embedding = pd.embedding(
input=trg_language_word,
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(name='vemb'))

rnn = fluid.layers.DynamicRNN()
rnn = pd.DynamicRNN()
with rnn.block():
current_word = rnn.step_input(trg_embedding)
mem = rnn.memory(init=encoder_out)
fc1 = fluid.layers.fc(input=[current_word, mem],
size=decoder_size,
act='tanh')
out = fluid.layers.fc(input=fc1, size=target_dict_dim, act='softmax')
mem = rnn.memory(init=context)
fc1 = pd.fc(input=[current_word, mem], size=decoder_size, act='tanh')
out = pd.fc(input=fc1, size=target_dict_dim, act='softmax')
rnn.update_memory(mem, fc1)
rnn.output(out)

return rnn()


def decoder_decode(context):
init_state = context
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
counter = pd.zeros(shape=[1], dtype='int64')
# fill the first element with init_state
mem_array = pd.create_array('float32')
pd.array_write(init_state, array=mem_array, i=counter)

# ids, scores as memory
ids_array = pd.create_array('int64')
scores_array = pd.create_array('float32')

init_ids = pd.data(name="init_ids", shape=[1], dtype="int64", lod_level=1)
init_scores = pd.data(
name="init_scores", shape=[1], dtype="float32", lod_level=1)

Review comment (Superjomn, author): initialization creates batch_size entries, with LoD [0 1 2 3] and one state per entry.

# init_ids = pd.ones(shape=[batch_size, 1], dtype='int64')
# init_scores = pd.ones(shape=[batch_size, 1], dtype='float32')
# init ids to [1..]
# init scores to [1.]
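Concretely, for batch_size = 3 the LoD mentioned in the review comment is [0, 1, 2, 3]: each source sentence starts with a single hypothesis. A minimal sketch of how decode_main below builds these initial tensors, using the LoDTensor API already present in this file:

```python
import numpy as np
import paddle.v2.fluid.core as core

batch_size = 3
init_ids_data = np.ones((batch_size, 1), dtype='int64')
init_lod = [i for i in range(batch_size)] + [batch_size]  # [0, 1, 2, 3]

init_ids = core.LoDTensor()
init_ids.set(init_ids_data, core.CPUPlace())
init_ids.set_lod([init_lod])  # one hypothesis per source sentence
```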
pd.array_write(init_ids, array=ids_array, i=counter)
pd.array_write(init_scores, array=scores_array, i=counter)

cond = pd.less_than(x=counter, y=array_len)

while_op = pd.While(cond=cond)
with while_op.block():
pre_ids = pd.array_read(array=ids_array, i=counter)
pre_state = pd.array_read(array=scores_array, i=counter)
# id = pd.array_read(array=ids_array, i=counter)
target_word = pd.embedding(
input=pre_ids,
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE, )

# pre_state_expanded = pd.sequence_expand(pre_state, init_scores)
# print 'pre_state', pre_state
# print 'target_word', target_word
# print 'pre_state_expanded', pre_state_expanded

# use rnn unit to update rnn
# TODO share parameter with trainer
# updated_hidden = pd.fc(input=[target_word, pre_state_expanded],
# size=hidden_dim,
# act='tanh')
# scores = pd.fc(input=updated_hidden,
# size=target_dict_dim,
# act='softmax')

# topk_scores, topk_indices = pd.topk(scores, k=50)
# selected_ids, selected_scores = pd.beam_search(
# pre_ids, topk_indices, topk_scores, beam_size, end_id=1)

# # update the memories
# pd.array_write(updated_hidden, array=mem_array, i=counter)
# pd.array_write(selected_ids, array=ids_array, i=counter)
# pd.array_write(selected_scores, array=scores_array, i=counter)
pd.increment(x=counter, value=1, in_place=True)
pd.less_than(x=counter, y=array_len, cond=cond)

# translation_ids, translation_scores = pd.beam_search_decode(
# ids=ids_array, scores=scores_array)

# return init_ids, init_scores

# return translation_ids, translation_scores


def set_init_lod(data, lod, place):
res = core.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res


def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
@@ -74,12 +74,13 @@ def to_lodtensor(data, place):
return res


def main():
rnn_out = encoder_decoder()
label = layers.data(
def train_main():
context = encoder()
rnn_out = decoder_train(context)
label = pd.data(
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
cost = layers.cross_entropy(input=rnn_out, label=label)
avg_cost = fluid.layers.mean(x=cost)
cost = pd.cross_entropy(input=rnn_out, label=label)
avg_cost = pd.mean(x=cost)

optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
optimizer.minimize(avg_cost)
@@ -89,13 +89,12 @@ def main():
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)

place = core.CPUPlace()
exe = Executor(place)

exe.run(framework.default_startup_program())

batch_id = 0
for pass_id in xrange(2):
for pass_id in xrange(1):
for data in train_data():
word_data = to_lodtensor(map(lambda x: x[0], data), place)
trg_word = to_lodtensor(map(lambda x: x[1], data), place)
@@ -111,9 +111,58 @@ def main():
print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
" avg_cost=" + str(avg_cost_val))
if batch_id > 3:
exit(0)
break
batch_id += 1


def decode_main():
# a new block
# with pd.BlockGuard() as block:
context = encoder()
# translation_ids, translation_scores = decoder_decode(context)
decoder_decode(context)
exe = Executor(place)
exe.run(framework.default_startup_program())

init_ids_data = np.array([1 for i in range(batch_size)], dtype='int64')
init_scores_data = np.array(
[1. for i in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [i for i in range(batch_size)] + [batch_size]

# print 'init_ids', init_ids
# print 'init_scores', init_scores

train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
for no, data in enumerate(train_data()):
init_ids = set_init_lod(init_ids_data, init_lod, place)
init_scores = set_init_lod(init_scores_data, init_lod, place)

print 'init_ids.dtype', init_ids.dtype()
print 'init_scores.dtype', init_scores.dtype()
print np.array(init_ids)
print np.array(init_scores)

word_data = to_lodtensor(map(lambda x: x[0], data), place)
trg_word = to_lodtensor(map(lambda x: x[1], data), place)
trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
exe.run(
framework.default_main_program(),
feed={
'src_word_id': word_data,
'init_ids': init_ids,
'init_scores': init_scores
},
#fetch_list=['init_ids', 'init_scores']
)
#fetch_list=[translation_ids, translation_scores])
break


if __name__ == '__main__':
main()
# train_main()
decode_main()
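
For reference, here is a sketch of the While-loop body that the commented-out region of decoder_decode is building toward, assembled from those comments. It assumes the module-level names of this test file (dict_size, word_dim, hidden_dim, target_dict_dim, beam_size, IS_SPARSE), and the TODO about sharing parameters with the trainer still applies:

```python
import paddle.v2.fluid.layers as pd

def decode_step(pre_ids, pre_state, init_scores, counter,
                mem_array, ids_array, scores_array):
    # Embed the ids selected at the previous step.
    target_word = pd.embedding(
        input=pre_ids,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=IS_SPARSE)
    # Expand the previous state so one copy lines up with each candidate.
    pre_state_expanded = pd.sequence_expand(pre_state, init_scores)
    # One RNN step: new hidden state, then per-word probabilities.
    updated_hidden = pd.fc(input=[target_word, pre_state_expanded],
                           size=hidden_dim,
                           act='tanh')
    scores = pd.fc(input=updated_hidden,
                   size=target_dict_dim,
                   act='softmax')
    # Prune to the top-k words, then keep beam_size hypotheses per sentence.
    topk_scores, topk_indices = pd.topk(scores, k=50)
    selected_ids, selected_scores = pd.beam_search(
        pre_ids, topk_indices, topk_scores, beam_size, end_id=1)
    # Persist this step's state, ids, and scores for the next iteration.
    pd.array_write(updated_hidden, array=mem_array, i=counter)
    pd.array_write(selected_ids, array=ids_array, i=counter)
    pd.array_write(selected_scores, array=scores_array, i=counter)
    return selected_ids, selected_scores
```

Once beam_search_decode is enabled, the accumulated ids_array and scores_array can be folded back into full translations, as the commented-out lines at the end of decoder_decode suggest.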