Skip to content

Commit f137f6b

Browse files
committed
fix viterbi max_seq_length bug
1 parent c176fca commit f137f6b

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

paddle/fluid/operators/viterbi_decode_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ class ViterbiDecodeOp : public framework::OperatorWithKernel {
6868
"The number of tags of Input and Transition should be equal."));
6969

7070
ctx->SetOutputDim("Scores", length_dims);
71-
ctx->SetOutputDim("Path", framework::make_ddim({in_dims[0], in_dims[1]}));
7271
}
7372

7473
protected:

paddle/fluid/operators/viterbi_decode_op.h

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,6 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
109109
auto seq_len = static_cast<int>(input->dims()[1]);
110110
auto n_labels = static_cast<int>(input->dims()[2]);
111111

112-
auto* scores = ctx.Output<Tensor>("Scores");
113-
auto* path = ctx.Output<Tensor>("Path");
114-
scores->mutable_data<T>(curr_place);
115-
path->mutable_data<int64_t>(curr_place);
116-
117112
// Create a large int data buffer
118113
int buffer_size = batch_size * seq_len +
119114
batch_size * n_labels * (seq_len - 1) + 7 * batch_size +
@@ -128,7 +123,23 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
128123
CREATE_TENSOR(float_buffer, T, buffer_size);
129124
TensorBuffer float_tensor_buffer(float_buffer);
130125

131-
Tensor temp_path = int_tensor_buffer.GetBufferBlock({seq_len, batch_size});
126+
auto* length = ctx.Input<Tensor>("Length");
127+
Tensor left_length = int_tensor_buffer.GetBufferBlock({batch_size, 1});
128+
framework::TensorCopy(*length, curr_place, dev_ctx, &left_length);
129+
130+
int64_t max_seq_len =
131+
*std::max_element(left_length.data<int64_t>(),
132+
left_length.data<int64_t>() + left_length.numel());
133+
134+
auto* scores = ctx.Output<Tensor>("Scores");
135+
scores->mutable_data<T>(curr_place);
136+
137+
auto* path = ctx.Output<Tensor>("Path");
138+
path->Resize({batch_size, max_seq_len});
139+
path->mutable_data<int64_t>(curr_place);
140+
141+
Tensor temp_path =
142+
int_tensor_buffer.GetBufferBlock({max_seq_len, batch_size});
132143
auto batch_path = Unbind(temp_path);
133144
for (auto it = batch_path.begin(); it != batch_path.end(); ++it) {
134145
it->Resize({batch_size});
@@ -145,14 +156,6 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
145156
float_tensor_buffer.GetBufferBlock({1, n_labels, n_labels});
146157
framework::TensorCopy(*transition, curr_place, dev_ctx, &trans_exp);
147158

148-
auto* length = ctx.Input<Tensor>("Length");
149-
Tensor left_length = int_tensor_buffer.GetBufferBlock({batch_size, 1});
150-
framework::TensorCopy(*length, curr_place, dev_ctx, &left_length);
151-
152-
int64_t max_seq_len =
153-
*std::max_element(left_length.data<int64_t>(),
154-
left_length.data<int64_t>() + left_length.numel());
155-
156159
Tensor alpha = float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
157160
math::SetConstant<platform::CPUDeviceContext, T> float_functor;
158161
math::SetConstant<platform::CPUDeviceContext, int64_t> int_functor;
@@ -278,7 +281,9 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
278281

279282
// last_ids_update = last_ids * tag_mask
280283
int last_ids_index = 1;
281-
MUL(last_ids, int_mask, batch_path[seq_len - last_ids_index], int64_t);
284+
int actual_len = std::min(seq_len, static_cast<int>(max_seq_len));
285+
286+
MUL(last_ids, int_mask, batch_path[actual_len - last_ids_index], int64_t);
282287
int64_t* batch_offset_ptr = batch_offset.data<int64_t>();
283288
for (int64_t i = 0; i < batch_size; ++i) {
284289
batch_offset_ptr[i] = i * n_labels;
@@ -290,7 +295,7 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
290295
ADD(batch_offset, last_ids, gather_idx, int64_t);
291296
// tag_mask = paddle.cast((left_length >= 0), 'int64')
292297
// last_ids_update = paddle.gather(hist.flatten(), gather_idx) * tag_mask
293-
Tensor& last_ids_update = batch_path[seq_len - last_ids_index];
298+
Tensor& last_ids_update = batch_path[actual_len - last_ids_index];
294299
hist->Resize({batch_size * n_labels});
295300
CPUGather<int64_t, int64_t>(dev_ctx, *hist, gather_idx, &last_ids_update);
296301
GET_CAST_MASK(left_length, zero, tag_mask, int_mask, GreaterEqualFunctor,

0 commit comments

Comments
 (0)