Skip to content

Commit 38f4b1d

Browse files
authored
Merge pull request #3430 from wangkuiyi/add_operatorbase_constructors
Add constructors to OperatorBase and all sub-classes
2 parents 19ab1dc + 65bd7c7 commit 38f4b1d

19 files changed

+71
-0
lines changed

paddle/framework/backward_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ using DeviceContext = platform::DeviceContext;
3030

3131
class EmptyOp : public OperatorBase {
3232
public:
33+
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase)
34+
3335
void InferShape(const Scope &scope) const override {}
3436
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
3537
};

paddle/framework/grad_op_builder_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace framework {
1010

1111
class NOP : public OperatorBase {
1212
public:
13+
DEFINE_OPERATOR_CTOR(NOP, OperatorBase)
14+
1315
void InferShape(const Scope &scope) const override {}
1416
void Run(const Scope &scope,
1517
const platform::DeviceContext &dev_ctx) const override {}

paddle/framework/op_registry_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ namespace paddle {
77
namespace framework {
88
class CosineOp : public OperatorBase {
99
public:
10+
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase)
11+
1012
void Run(const Scope& scope,
1113
const platform::DeviceContext& dev_ctx) const override {}
1214
void InferShape(const Scope& scope) const override {}
@@ -27,6 +29,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
2729

2830
class MyTestOp : public OperatorBase {
2931
public:
32+
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase)
33+
3034
void InferShape(const Scope& scope) const override {}
3135
void Run(const Scope& scope,
3236
const platform::DeviceContext& dev_ctx) const override {}

paddle/framework/operator.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ class ExecutionContext;
6363
*/
6464
class OperatorBase {
6565
public:
66+
OperatorBase() {} // TODO(yi): This constructor is to be removed.
67+
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
68+
const std::vector<std::string>& outputs,
69+
const AttributeMap& attrs,
70+
std::unordered_map<std::string, int>* in_out_idxs)
71+
: type_(type),
72+
inputs_(inputs),
73+
outputs_(outputs),
74+
attrs_(attrs),
75+
in_out_idxs_(in_out_idxs) {}
76+
6677
virtual ~OperatorBase() {}
6778

6879
template <typename T>
@@ -109,6 +120,9 @@ class OperatorBase {
109120
const std::vector<std::string> Inputs() const { return inputs_; }
110121
const std::vector<std::string> Outputs() const { return outputs_; }
111122
const AttributeMap& Attrs() const { return attrs_; }
123+
const std::unordered_map<std::string, int>* InOutIdx() const {
124+
return in_out_idxs_.get();
125+
}
112126

113127
public:
114128
std::string type_;
@@ -286,6 +300,14 @@ class OpKernel {
286300

287301
class OperatorWithKernel : public OperatorBase {
288302
public:
303+
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
304+
OperatorWithKernel(const std::string& type,
305+
const std::vector<std::string>& inputs,
306+
const std::vector<std::string>& outputs,
307+
const AttributeMap& attrs,
308+
std::unordered_map<std::string, int>* in_out_idxs)
309+
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
310+
289311
struct OpKernelKey {
290312
platform::Place place_;
291313

@@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
335357
virtual void InferShape(const InferShapeContext& ctx) const = 0;
336358
};
337359

360+
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
361+
public: \
362+
Class() { /* TODO(yi): This constructor is to be removed. */ \
363+
} \
364+
Class(const std::string& type, const std::vector<std::string>& inputs, \
365+
const std::vector<std::string>& outputs, \
366+
const ::paddle::framework::AttributeMap& attrs, \
367+
std::unordered_map<std::string, int>* in_out_idxs) \
368+
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
369+
338370
} // namespace framework
339371
} // namespace paddle

paddle/framework/operator_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ static int op_run_num = 0;
2323

2424
class OpWithoutKernelTest : public OperatorBase {
2525
public:
26+
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase)
27+
2628
void Init() override { x = 1; }
2729
void InferShape(const Scope& scope) const override {}
2830
void Run(const Scope& scope,
@@ -97,6 +99,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
9799
static int cpu_kernel_run_num = 0;
98100

99101
class OpWithKernelTest : public OperatorWithKernel {
102+
public:
103+
DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel)
100104
protected:
101105
void InferShape(const framework::InferShapeContext& ctx) const override {}
102106
};
@@ -116,6 +120,8 @@ class CPUKernelTest : public OpKernel {
116120
// multiple inputs test
117121
class OperatorMultiInputsTest : public OperatorBase {
118122
public:
123+
DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase)
124+
119125
void Init() override { x = 1; }
120126
void InferShape(const Scope& scope) const override {}
121127
void Run(const Scope& scope,

paddle/operators/add_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class AddOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
@@ -47,6 +48,7 @@ The equation is: Out = X + Y
4748
};
4849

4950
class AddOpGrad : public framework::OperatorWithKernel {
51+
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
5052
protected:
5153
void InferShape(const framework::InferShapeContext &ctx) const override {}
5254
};

paddle/operators/cross_entropy_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
@@ -38,6 +39,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
3839
};
3940

4041
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
42+
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
43+
framework::OperatorWithKernel)
4144
protected:
4245
void InferShape(const framework::InferShapeContext &ctx) const override {
4346
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

paddle/operators/fill_zeros_like_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class FillZerosLikeOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,

paddle/operators/gaussian_random_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class GaussianRandomKernel : public framework::OpKernel {
4343
};
4444

4545
class GaussianRandomOp : public framework::OperatorWithKernel {
46+
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel)
4647
protected:
4748
void InferShape(const framework::InferShapeContext& context) const override {
4849
auto* tensor = context.Output<framework::Tensor>(0);

paddle/operators/mean_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace paddle {
1818
namespace operators {
1919

2020
class MeanOp : public framework::OperatorWithKernel {
21+
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
2122
protected:
2223
void InferShape(const framework::InferShapeContext &ctx) const override {
2324
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
@@ -39,6 +40,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
3940
};
4041

4142
class MeanGradOp : public framework::OperatorWithKernel {
43+
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
4244
protected:
4345
void InferShape(const framework::InferShapeContext &ctx) const override {
4446
ctx.Output<Tensor>(framework::GradVarName("X"))

0 commit comments

Comments
 (0)