@@ -109,11 +109,6 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
109109 auto seq_len = static_cast <int >(input->dims ()[1 ]);
110110 auto n_labels = static_cast <int >(input->dims ()[2 ]);
111111
112- auto * scores = ctx.Output <Tensor>(" Scores" );
113- auto * path = ctx.Output <Tensor>(" Path" );
114- scores->mutable_data <T>(curr_place);
115- path->mutable_data <int64_t >(curr_place);
116-
117112 // Create a large int data buffer
118113 int buffer_size = batch_size * seq_len +
119114 batch_size * n_labels * (seq_len - 1 ) + 7 * batch_size +
@@ -128,7 +123,23 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
128123 CREATE_TENSOR (float_buffer, T, buffer_size);
129124 TensorBuffer float_tensor_buffer (float_buffer);
130125
131- Tensor temp_path = int_tensor_buffer.GetBufferBlock ({seq_len, batch_size});
126+ auto * length = ctx.Input <Tensor>(" Length" );
127+ Tensor left_length = int_tensor_buffer.GetBufferBlock ({batch_size, 1 });
128+ framework::TensorCopy (*length, curr_place, dev_ctx, &left_length);
129+
130+ int64_t max_seq_len =
131+ *std::max_element (left_length.data <int64_t >(),
132+ left_length.data <int64_t >() + left_length.numel ());
133+
134+ auto * scores = ctx.Output <Tensor>(" Scores" );
135+ scores->mutable_data <T>(curr_place);
136+
137+ auto * path = ctx.Output <Tensor>(" Path" );
138+ path->Resize ({batch_size, max_seq_len});
139+ path->mutable_data <int64_t >(curr_place);
140+
141+ Tensor temp_path =
142+ int_tensor_buffer.GetBufferBlock ({max_seq_len, batch_size});
132143 auto batch_path = Unbind (temp_path);
133144 for (auto it = batch_path.begin (); it != batch_path.end (); ++it) {
134145 it->Resize ({batch_size});
@@ -145,14 +156,6 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
145156 float_tensor_buffer.GetBufferBlock ({1 , n_labels, n_labels});
146157 framework::TensorCopy (*transition, curr_place, dev_ctx, &trans_exp);
147158
148- auto * length = ctx.Input <Tensor>(" Length" );
149- Tensor left_length = int_tensor_buffer.GetBufferBlock ({batch_size, 1 });
150- framework::TensorCopy (*length, curr_place, dev_ctx, &left_length);
151-
152- int64_t max_seq_len =
153- *std::max_element (left_length.data <int64_t >(),
154- left_length.data <int64_t >() + left_length.numel ());
155-
156159 Tensor alpha = float_tensor_buffer.GetBufferBlock ({batch_size, n_labels});
157160 math::SetConstant<platform::CPUDeviceContext, T> float_functor;
158161 math::SetConstant<platform::CPUDeviceContext, int64_t > int_functor;
@@ -278,7 +281,9 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
278281
279282 // last_ids_update = last_ids * tag_mask
280283 int last_ids_index = 1 ;
281- MUL (last_ids, int_mask, batch_path[seq_len - last_ids_index], int64_t );
284+ int actual_len = std::min (seq_len, static_cast <int >(max_seq_len));
285+
286+ MUL (last_ids, int_mask, batch_path[actual_len - last_ids_index], int64_t );
282287 int64_t * batch_offset_ptr = batch_offset.data <int64_t >();
283288 for (int64_t i = 0 ; i < batch_size; ++i) {
284289 batch_offset_ptr[i] = i * n_labels;
@@ -290,7 +295,7 @@ class ViterbiDecodeCPUKernel : public framework::OpKernel<T> {
290295 ADD (batch_offset, last_ids, gather_idx, int64_t );
291296 // tag_mask = paddle.cast((left_length >= 0), 'int64')
292297 // last_ids_update = paddle.gather(hist.flatten(), gather_idx) * tag_mask
293- Tensor& last_ids_update = batch_path[seq_len - last_ids_index];
298+ Tensor& last_ids_update = batch_path[actual_len - last_ids_index];
294299 hist->Resize ({batch_size * n_labels});
295300 CPUGather<int64_t , int64_t >(dev_ctx, *hist, gather_idx, &last_ids_update);
296301 GET_CAST_MASK (left_length, zero, tag_mask, int_mask, GreaterEqualFunctor,
0 commit comments