Skip to content

Commit 0ff8192

Browse files
committed
Add OperatorWithKernel class
* User can register OpKernel to its Ops. The OpKernelMap saved in OperatorWithKernel. Each Op which inherits OperatorWithKernel will use `OpKernel::Compute` instead of Run.
1 parent 2749b71 commit 0ff8192

File tree

9 files changed

+127
-153
lines changed

9 files changed

+127
-153
lines changed

paddle/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ if(Boost_FOUND)
1515
add_subdirectory(memory)
1616
add_subdirectory(platform)
1717
add_subdirectory(framework)
18-
add_subdirectory(operators)
1918
add_subdirectory(pybind)
2019
endif()
2120

paddle/framework/op_registry_test.cc

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
#include "paddle/framework/op_registry.h"
22
#include <gtest/gtest.h>
3-
#include "paddle/framework/operator.h"
4-
#include "paddle/operators/demo_op.h"
53

64
using namespace paddle::framework;
75

86
namespace paddle {
97
namespace framework {
10-
class CosineOp : public OperatorWithKernel {
8+
class CosineOp : public OperatorBase {
119
public:
12-
void Run(const OpRunContext* context) const override {
13-
printf("%s\n", DebugString().c_str());
14-
}
10+
void Run(const std::shared_ptr<Scope>& scope,
11+
const platform::DeviceContext& dev_ctx) const override {}
12+
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
1513
};
1614

1715
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
3028

3129
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
3230

33-
class MyTestOp : public OperatorWithKernel {
31+
class MyTestOp : public OperatorBase {
32+
public:
33+
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
34+
void Run(const std::shared_ptr<Scope>& scope,
35+
const platform::DeviceContext& dev_ctx) const override {}
36+
3437
public:
35-
void Run(const OpRunContext* ctx) const override {
36-
printf("%s\n", DebugString().c_str());
37-
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
38-
}
3938
};
4039

4140
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
7372
paddle::framework::OperatorBase* op =
7473
paddle::framework::OpRegistry::CreateOp(op_desc);
7574
auto scope = std::make_shared<Scope>();
76-
auto dev_ctx = DeviceContext();
77-
op->Run(scope, &dev_ctx);
75+
paddle::platform::CPUDeviceContext dev_ctx;
76+
op->Run(scope, dev_ctx);
7877
float scale_get = op->GetAttr<float>("scale");
7978
ASSERT_EQ(scale_get, scale);
8079
}
@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
116115
paddle::framework::OperatorBase* op =
117116
paddle::framework::OpRegistry::CreateOp(op_desc);
118117
auto scope = std::make_shared<Scope>();
119-
auto dev_ctx = DeviceContext();
120-
op->Run(scope, &dev_ctx);
118+
paddle::platform::CPUDeviceContext dev_ctx;
119+
op->Run(scope, dev_ctx);
121120
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
122121
}
123122

@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
169168
attr->set_i(4);
170169
paddle::framework::OperatorBase* op =
171170
paddle::framework::OpRegistry::CreateOp(op_desc);
172-
auto dev_ctx = DeviceContext();
171+
paddle::platform::CPUDeviceContext dev_ctx;
173172
auto scope = std::make_shared<Scope>();
174-
op->Run(scope, &dev_ctx);
173+
op->Run(scope, dev_ctx);
175174
int test_attr = op->GetAttr<int>("test_attr");
176175
ASSERT_EQ(test_attr, 4);
177176
}

paddle/framework/operator.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
3939
return ss.str();
4040
}
4141

42-
const Variable* OpRunContext::Input(int index) const {
43-
return scope_->GetVariable(op_->inputs_[index]);
44-
}
45-
46-
Variable* OpRunContext::Output(int index) const {
47-
return scope_->GetVariable(op_->outputs_[index]);
48-
}
49-
5042
} // namespace framework
5143
} // namespace paddle

paddle/framework/operator.h

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,22 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <paddle/framework/attr_checker.h>
18+
#include <paddle/framework/op_desc.pb.h>
19+
#include <paddle/framework/scope.h>
20+
#include <paddle/platform/device_context.h>
21+
#include <paddle/platform/place.h>
22+
#include <paddle/utils/Error.h>
1723
#include <boost/variant.hpp>
1824
#include <string>
1925
#include <unordered_map>
2026
#include <vector>
2127

22-
#include "paddle/framework/attr_checker.h"
23-
#include "paddle/framework/op_desc.pb.h"
24-
#include "paddle/framework/scope.h"
25-
#include "paddle/utils/Error.h"
26-
2728
namespace paddle {
2829
namespace framework {
2930

3031
class OperatorBase;
3132

32-
class DeviceContext {};
33-
34-
/**
35-
* OpRunContext is the only parameter of Operator's Run function.
36-
* Run will get input/output variables, state such as momentum and
37-
* device resource such as CUDA stream, cublas handle, etc. from
38-
* OpRunContext. User should construct it before run the Operator.
39-
*/
40-
class OpRunContext {
41-
public:
42-
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
43-
const DeviceContext* device_context)
44-
: op_(op), scope_(scope), device_context_(device_context) {}
45-
46-
const Variable* Input(int index) const;
47-
Variable* Output(int index) const;
48-
49-
public:
50-
const OperatorBase* op_;
51-
const std::shared_ptr<Scope> scope_;
52-
const DeviceContext* device_context_;
53-
};
54-
5533
/**
5634
* OperatorBase has the basic element that Net will call to do computation.
5735
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -77,7 +55,10 @@ class OperatorBase {
7755

7856
/// Net will call this function to Run an op.
7957
virtual void Run(const std::shared_ptr<Scope>& scope,
80-
const DeviceContext* dev_ctx) const = 0;
58+
const platform::DeviceContext& dev_ctx) const = 0;
59+
60+
protected:
61+
std::string Type() const { return desc_.type(); }
8162

8263
public:
8364
OpDesc desc_;
@@ -86,22 +67,84 @@ class OperatorBase {
8667
AttributeMap attrs_;
8768
};
8869

70+
class OpKernel {
71+
public:
72+
/**
73+
* KernelContext is the only parameter of Kernel Run function.
74+
* Run will get input/output variables, state such as momentum and
75+
* device resource such as CUDA stream, cublas handle, etc. from
76+
* KernelContext. User should construct it before run the Operator.
77+
*/
78+
class KernelContext {
79+
public:
80+
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
81+
const platform::DeviceContext& device_context)
82+
: op_(*op), scope_(scope), device_context_(device_context) {}
83+
84+
const Variable* Input(int index) const {
85+
return scope_->GetVariable(op_.inputs_[index]);
86+
}
87+
88+
Variable* Output(int index) const {
89+
return scope_->GetVariable(op_.outputs_[index]);
90+
}
91+
92+
const OperatorBase& op_;
93+
const std::shared_ptr<Scope>& scope_;
94+
const platform::DeviceContext& device_context_;
95+
};
96+
97+
virtual void Compute(const KernelContext& context) const = 0;
98+
99+
virtual ~OpKernel() {}
100+
};
101+
89102
class OperatorWithKernel : public OperatorBase {
90103
public:
91-
virtual ~OperatorWithKernel() {}
104+
struct OpKernelKey {
105+
platform::Place place_;
92106

93-
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
107+
OpKernelKey() = default;
108+
OpKernelKey(const platform::DeviceContext& dev_ctx) {
109+
place_ = dev_ctx.GetPlace();
110+
}
111+
112+
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
113+
};
114+
115+
struct OpKernelHash {
116+
std::hash<bool> hash_;
117+
size_t operator()(const OpKernelKey& key) const {
118+
return hash_(platform::is_gpu_place(key.place_));
119+
}
120+
};
121+
122+
using OpKernelMap =
123+
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
94124

95125
void Run(const std::shared_ptr<Scope>& scope,
96-
const DeviceContext* dev_ctx) const {
97-
OpRunContext op_ctx(this, scope, dev_ctx);
98-
Run(&op_ctx);
126+
const platform::DeviceContext& dev_ctx) const final {
127+
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
128+
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
99129
}
100130

101-
/// when implement an Op, your should implement this function.
102-
/// this function should be moved to OpKernel later
103-
virtual void Run(const OpRunContext* context) const = 0;
131+
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
132+
AllOpKernels() {
133+
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
134+
return g_all_op_kernels;
135+
};
104136
};
105137

106138
} // namespace framework
107139
} // namespace paddle
140+
141+
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
142+
struct __op_kernel_register__##type##__ { \
143+
__op_kernel_register__##type##__() { \
144+
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
145+
key.place_ = PlaceType(); \
146+
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
147+
.reset(new KernelType()); \
148+
} \
149+
}; \
150+
static __op_kernel_register__##type##__ __reg_kernel_##type##__

paddle/framework/operator_test.cc

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,15 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace framework {
2121

22-
class OperatorTest : public OperatorWithKernel {
22+
class OperatorTest : public OperatorBase {
2323
public:
24-
void Run(const OpRunContext* ctx) const override {
25-
float scale = ctx->op_->GetAttr<float>("scale");
26-
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
27-
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
28-
"Output(1) should not initialized");
29-
auto output1 = ctx->scope_->CreateVariable("output1");
30-
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
31-
printf("get attr %s = %f\n", "scale", scale);
32-
printf("%s\n", DebugString().c_str());
24+
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
25+
void Run(const std::shared_ptr<Scope>& scope,
26+
const platform::DeviceContext& dev_ctx) const override {
27+
float scale = GetAttr<float>("scale");
28+
ASSERT_NEAR(scale, 3.14, 1e-5);
29+
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
30+
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
3331
}
3432
};
3533

@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
4947

5048
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
5149

52-
TEST(OperatorBase, DebugString) {
50+
TEST(OperatorBase, all) {
5351
OpDesc op_desc;
5452
op_desc.set_type("test_operator");
55-
std::vector<std::string> inputs = {"IN1", "IN2"};
56-
for (auto& input : inputs) {
57-
op_desc.add_inputs(input);
58-
}
59-
std::vector<std::string> outputs = {"OUT1", "OUT2"};
60-
for (auto& output : outputs) {
61-
op_desc.add_outputs(output);
62-
}
53+
*op_desc.mutable_inputs()->Add() = "IN1";
54+
*op_desc.mutable_outputs()->Add() = "OUT1";
6355
auto attr = op_desc.mutable_attrs()->Add();
6456
attr->set_name("scale");
6557
attr->set_type(paddle::framework::AttrType::FLOAT);
6658
float scale = 3.14;
6759
attr->set_f(scale);
6860

69-
DeviceContext device_context;
61+
platform::CPUDeviceContext device_context;
7062
auto scope = std::make_shared<Scope>();
7163

7264
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
73-
ASSERT_EQ(op->inputs_, inputs);
74-
ASSERT_EQ(op->outputs_, outputs);
7565
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
76-
op->Run(scope, &device_context);
66+
scope->CreateVariable("OUT1");
67+
op->Run(scope, device_context);
68+
std::cout << op->DebugString() << std::endl;
69+
delete op;
7770
}
7871

7972
} // namespace framework

paddle/operators/.clang-format

Lines changed: 0 additions & 5 deletions
This file was deleted.

paddle/operators/CMakeLists.txt

Whitespace-only changes.

paddle/operators/demo_op.h

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)