Skip to content

Commit b97e6d1

Browse files
authored
[phi] move viterbi_decode to phi (#40186)
* move viterbi to phi * move infershape to phi * update infershape * fix * resolve conflicts
1 parent 452c75b commit b97e6d1

File tree

9 files changed

+953
-690
lines changed

9 files changed

+953
-690
lines changed

paddle/fluid/operators/viterbi_decode_op.cc

Lines changed: 7 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

12-
#include "paddle/fluid/operators/viterbi_decode_op.h"
12+
#include "paddle/fluid/framework/infershape_utils.h"
1313
#include "paddle/fluid/framework/op_registry.h"
14+
#include "paddle/phi/core/infermeta_utils.h"
15+
#include "paddle/phi/infermeta/ternary.h"
1416

1517
namespace paddle {
1618
namespace operators {
@@ -19,47 +21,6 @@ class ViterbiDecodeOp : public framework::OperatorWithKernel {
1921
public:
2022
using framework::OperatorWithKernel::OperatorWithKernel;
2123

22-
void InferShape(framework::InferShapeContext* ctx) const override {
23-
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode");
24-
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
25-
"ViterbiDecode");
26-
OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode");
27-
OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores",
28-
"ViterbiDecode");
29-
OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode");
30-
auto in_dims = ctx->GetInputDim("Input");
31-
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
32-
platform::errors::InvalidArgument(
33-
"The rank of Input in ViterbiDecode must be 3. But "
34-
"received Input's rank is %d.",
35-
in_dims.size()));
36-
auto length_dims = ctx->GetInputDim("Length");
37-
PADDLE_ENFORCE_EQ(length_dims.size(), 1,
38-
platform::errors::InvalidArgument(
39-
"The rank of Length in ViterbiDecode must be 1. But "
40-
"received Length's rank is %d.",
41-
length_dims.size()));
42-
auto transition_dims = ctx->GetInputDim("Transition");
43-
PADDLE_ENFORCE_EQ(
44-
transition_dims.size(), 2,
45-
platform::errors::InvalidArgument(
46-
"The rank of Transition in ViterbiDecode must be 2. But "
47-
"received Transition's rank is %d.",
48-
transition_dims.size()));
49-
if (ctx->IsRuntime()) {
50-
PADDLE_ENFORCE_EQ(
51-
in_dims[0], length_dims[0],
52-
platform::errors::InvalidArgument(
53-
"The batch size of Input and Length should be equal."));
54-
PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0],
55-
platform::errors::InvalidArgument(
56-
"The number of tags of Input (%d) and Transition "
57-
"(%d) should be equal.",
58-
transition_dims[0], in_dims[2]));
59-
}
60-
ctx->SetOutputDim("Scores", length_dims);
61-
}
62-
6324
protected:
6425
framework::OpKernelType GetExpectedKernelType(
6526
const framework::ExecutionContext& ctx) const override {
@@ -102,8 +63,8 @@ class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
10263

10364
namespace ops = paddle::operators;
10465
namespace platform = paddle::platform;
66+
DECLARE_INFER_SHAPE_FUNCTOR(viterbi_decode, ViterbiDecodeInferShapeFunctor,
67+
PD_INFER_META(phi::ViterbiDecodeInferMeta));
10568
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
106-
ops::ViterbiDecodeOpMaker);
107-
REGISTER_OP_CPU_KERNEL(
108-
viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
109-
ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
69+
ops::ViterbiDecodeOpMaker,
70+
ViterbiDecodeInferShapeFunctor);

paddle/fluid/operators/viterbi_decode_op.cu

Lines changed: 0 additions & 206 deletions
This file was deleted.

0 commit comments

Comments
 (0)