@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99See the License for the specific language governing permissions and
1010limitations under the License. */
1111
12- #include " paddle/fluid/operators/viterbi_decode_op .h"
12+ #include " paddle/fluid/framework/infershape_utils .h"
1313#include " paddle/fluid/framework/op_registry.h"
14+ #include " paddle/phi/core/infermeta_utils.h"
15+ #include " paddle/phi/infermeta/ternary.h"
1416
1517namespace paddle {
1618namespace operators {
@@ -19,47 +21,6 @@ class ViterbiDecodeOp : public framework::OperatorWithKernel {
1921 public:
2022 using framework::OperatorWithKernel::OperatorWithKernel;
2123
22- void InferShape (framework::InferShapeContext* ctx) const override {
23- OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " ViterbiDecode" );
24- OP_INOUT_CHECK (ctx->HasInput (" Transition" ), " Input" , " Transition" ,
25- " ViterbiDecode" );
26- OP_INOUT_CHECK (ctx->HasInput (" Length" ), " Input" , " Length" , " ViterbiDecode" );
27- OP_INOUT_CHECK (ctx->HasOutput (" Scores" ), " Output" , " Scores" ,
28- " ViterbiDecode" );
29- OP_INOUT_CHECK (ctx->HasOutput (" Path" ), " Output" , " Path" , " ViterbiDecode" );
30- auto in_dims = ctx->GetInputDim (" Input" );
31- PADDLE_ENFORCE_EQ (in_dims.size (), 3 ,
32- platform::errors::InvalidArgument (
33- " The rank of Input in ViterbiDecode must be 3. But "
34- " received Input's rank is %d." ,
35- in_dims.size ()));
36- auto length_dims = ctx->GetInputDim (" Length" );
37- PADDLE_ENFORCE_EQ (length_dims.size (), 1 ,
38- platform::errors::InvalidArgument (
39- " The rank of Length in ViterbiDecode must be 1. But "
40- " received Length's rank is %d." ,
41- length_dims.size ()));
42- auto transition_dims = ctx->GetInputDim (" Transition" );
43- PADDLE_ENFORCE_EQ (
44- transition_dims.size (), 2 ,
45- platform::errors::InvalidArgument (
46- " The rank of Transition in ViterbiDecode must be 2. But "
47- " received Transition's rank is %d." ,
48- transition_dims.size ()));
49- if (ctx->IsRuntime ()) {
50- PADDLE_ENFORCE_EQ (
51- in_dims[0 ], length_dims[0 ],
52- platform::errors::InvalidArgument (
53- " The batch size of Input and Length should be equal." ));
54- PADDLE_ENFORCE_EQ (in_dims[2 ], transition_dims[0 ],
55- platform::errors::InvalidArgument (
56- " The number of tags of Input (%d) and Transition "
57- " (%d) should be equal." ,
58- transition_dims[0 ], in_dims[2 ]));
59- }
60- ctx->SetOutputDim (" Scores" , length_dims);
61- }
62-
6324 protected:
6425 framework::OpKernelType GetExpectedKernelType (
6526 const framework::ExecutionContext& ctx) const override {
@@ -102,8 +63,8 @@ class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
10263
10364namespace ops = paddle::operators;
10465namespace platform = paddle::platform;
66+ DECLARE_INFER_SHAPE_FUNCTOR (viterbi_decode, ViterbiDecodeInferShapeFunctor,
67+ PD_INFER_META (phi::ViterbiDecodeInferMeta));
10568REGISTER_OP_WITHOUT_GRADIENT (viterbi_decode, ops::ViterbiDecodeOp,
106- ops::ViterbiDecodeOpMaker);
107- REGISTER_OP_CPU_KERNEL (
108- viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float >,
109- ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double >);
69+ ops::ViterbiDecodeOpMaker,
70+ ViterbiDecodeInferShapeFunctor);
0 commit comments