Skip to content

Commit 484a2da

Browse files
authored
Merge pull request #3587 from reyoung/feature/extract_op_info_into_op_info.cc
Feature/extract op info into op info.cc
2 parents 7c8acd4 + 760cb6c commit 484a2da

File tree

15 files changed

+219
-125
lines changed

15 files changed

+219
-125
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
1818
proto_library(framework_proto SRCS framework.proto)
1919

2020
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
21-
22-
cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute)
21+
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
22+
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
2323
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
2424

2525
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)

paddle/framework/backward_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
7272

7373
class FcOp : public operators::NetOp {
7474
public:
75-
FcOp(const std::string &type, const VarNameMap &inputs,
76-
const VarNameMap &outputs, const AttributeMap &attrs)
75+
FcOp(const std::string &type, const VariableNameMap &inputs,
76+
const VariableNameMap &outputs, const AttributeMap &attrs)
7777
: NetOp(type, inputs, outputs, attrs) {
7878
AppendOp(OpRegistry::CreateOp("mul",
7979
{{"X", {Input("X")}}, {"Y", {Input("W")}}},

paddle/framework/grad_op_builder.cc

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ namespace framework {
2020
enum class OpArgType { IN, OUT };
2121

2222
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
23-
bool is_grad, OperatorBase::VarNameMap* vars) {
23+
bool is_grad, VariableNameMap* vars) {
2424
const auto& src_inout =
2525
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
2626
auto& dst_inout = *vars;
27-
const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_;
27+
auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
2828
const auto& src_arg_list =
29-
src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
29+
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
3030
for (const auto& arg : src_arg_list) {
3131
if (arg.not_in_gradient() && !is_grad) continue;
3232
const std::string src_name = arg.name();
@@ -40,26 +40,18 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
4040
}
4141

4242
OperatorBase* BuildGradOp(const OperatorBase* op) {
43-
auto it = OpRegistry::op_info_map().find(op->Type());
44-
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
45-
"'%s' has not been registered.", op->Type());
46-
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
47-
op->Type());
48-
std::string grad_op_type = it->second.grad_op_type_;
49-
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
50-
op->Type());
43+
auto& info = OpInfoMap::Instance().Get(op->Type());
44+
PADDLE_ENFORCE(info.HasGradientOp());
5145

52-
OperatorBase::VarNameMap inputs;
53-
OperatorBase::VarNameMap outputs;
46+
VariableNameMap inputs;
47+
VariableNameMap outputs;
5448
TransOpArg(op, OpArgType::IN, false, &inputs); // I
5549
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
5650
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
5751
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
5852

59-
it = OpRegistry::op_info_map().find(grad_op_type);
60-
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
61-
"'%s' has not been registered.", grad_op_type);
62-
return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
53+
auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
54+
return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
6355
}
6456

6557
} // namespace framework

paddle/framework/op_info.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/op_info.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
20+
static OpInfoMap* g_op_info_map = nullptr;
21+
22+
OpInfoMap& OpInfoMap::Instance() {
23+
if (g_op_info_map == nullptr) {
24+
g_op_info_map = new OpInfoMap();
25+
}
26+
return *g_op_info_map;
27+
}
28+
} // namespace framework
29+
} // namespace paddle

paddle/framework/op_info.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <functional>
17+
#include <map>
18+
#include <string>
19+
#include <unordered_map>
20+
21+
#include "paddle/framework/attribute.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
class OperatorBase;
26+
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
27+
28+
using OpCreator = std::function<OperatorBase*(
29+
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
30+
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
31+
32+
struct OpInfo {
33+
OpCreator creator_;
34+
std::string grad_op_type_;
35+
OpProto* proto_;
36+
OpAttrChecker* checker_;
37+
38+
bool HasOpProtoAndChecker() const {
39+
return proto_ != nullptr && checker_ != nullptr;
40+
}
41+
42+
const OpProto& Proto() const {
43+
PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
44+
PADDLE_ENFORCE(proto_->IsInitialized(),
45+
"Operator Proto must be initialized in op info");
46+
return *proto_;
47+
}
48+
49+
const OpAttrChecker& Checker() const {
50+
PADDLE_ENFORCE_NOT_NULL(checker_,
51+
"Operator Checker has not been registered");
52+
return *checker_;
53+
}
54+
55+
const OpCreator& Creator() const {
56+
PADDLE_ENFORCE_NOT_NULL(creator_,
57+
"Operator Creator has not been registered");
58+
return creator_;
59+
}
60+
61+
bool HasGradientOp() const { return !grad_op_type_.empty(); }
62+
};
63+
64+
class OpInfoMap {
65+
public:
66+
static OpInfoMap& Instance();
67+
68+
OpInfoMap(const OpInfoMap& o) = delete;
69+
OpInfoMap(OpInfoMap&& o) = delete;
70+
OpInfoMap& operator=(const OpInfoMap& o) = delete;
71+
OpInfoMap& operator=(OpInfoMap&& o) = delete;
72+
73+
bool Has(const std::string& op_type) const {
74+
return map_.find(op_type) != map_.end();
75+
}
76+
77+
void Insert(const std::string& type, const OpInfo& info) {
78+
PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
79+
map_.insert({type, info});
80+
}
81+
82+
const OpInfo& Get(const std::string& type) const {
83+
auto it = map_.find(type);
84+
PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type);
85+
return it->second;
86+
}
87+
88+
template <typename Callback>
89+
void IterAllInfo(Callback callback) {
90+
for (auto& it : map_) {
91+
callback(it.first, it.second);
92+
}
93+
}
94+
95+
private:
96+
OpInfoMap() = default;
97+
std::unordered_map<std::string, const OpInfo> map_;
98+
};
99+
100+
} // namespace framework
101+
} // namespace paddle

paddle/framework/op_registry.cc

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,18 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace framework {
2121

22-
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
23-
const VarNameMap& inputs,
24-
const VarNameMap& outputs,
25-
AttributeMap attrs) {
26-
auto it = op_info_map().find(type);
27-
PADDLE_ENFORCE(it != op_info_map().end(),
28-
"Operator '%s' has not been registered.", type);
29-
it->second.checker_->Check(attrs);
30-
auto op = it->second.creator_(type, inputs, outputs, attrs);
22+
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
23+
const std::string& type, const VariableNameMap& inputs,
24+
const VariableNameMap& outputs, AttributeMap attrs) {
25+
auto& info = OpInfoMap::Instance().Get(type);
26+
info.Checker().Check(attrs);
27+
auto op = info.Creator()(type, inputs, outputs, attrs);
3128
return std::unique_ptr<OperatorBase>(op);
3229
}
3330

34-
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
35-
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
36-
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
37-
AttributeMap attrs;
38-
for (auto& attr : op_desc.attrs()) {
39-
attrs[attr.name()] = GetAttrValue(attr);
40-
}
41-
42-
return CreateOp(op_desc.type(), inputs, outputs, attrs);
43-
}
44-
45-
OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
31+
static VariableNameMap ConvertOpDescVarsToVarNameMap(
4632
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
47-
VarNameMap ret_val;
33+
VariableNameMap ret_val;
4834
for (auto& var : op_desc_vars) {
4935
auto& var_names = ret_val[var.parameter()];
5036
auto& var_names_in_proto = var.arguments();
@@ -55,6 +41,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
5541
return ret_val;
5642
}
5743

44+
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
45+
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
46+
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
47+
AttributeMap attrs;
48+
for (auto& attr : op_desc.attrs()) {
49+
attrs[attr.name()] = GetAttrValue(attr);
50+
}
51+
52+
return CreateOp(op_desc.type(), inputs, outputs, attrs);
53+
}
54+
5855
std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
5956
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
6057
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));

paddle/framework/op_registry.h

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,24 @@ limitations under the License. */
2323
#include "paddle/framework/attribute.h"
2424
#include "paddle/framework/framework.pb.h"
2525
#include "paddle/framework/grad_op_builder.h"
26+
#include "paddle/framework/op_info.h"
2627
#include "paddle/framework/operator.h"
2728
#include "paddle/framework/scope.h"
2829

2930
namespace paddle {
3031
namespace framework {
3132

3233
class OpRegistry {
33-
using VarNameMap = OperatorBase::VarNameMap;
34-
using OpCreator = std::function<OperatorBase*(
35-
const std::string& /*type*/, const VarNameMap& /*inputs*/,
36-
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
37-
3834
public:
39-
struct OpInfo {
40-
OpCreator creator_;
41-
std::string grad_op_type_;
42-
OpProto* proto_;
43-
OpAttrChecker* checker_;
44-
};
45-
4635
template <typename OpType, typename ProtoMakerType, typename GradOpType>
4736
static void RegisterOp(const std::string& op_type,
4837
const std::string& grad_op_type) {
49-
PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
38+
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
5039
"'%s' is registered more than once.", op_type);
5140
OpInfo op_info;
52-
op_info.creator_ = [](const std::string& type, const VarNameMap& inputs,
53-
const VarNameMap& outputs,
54-
const AttributeMap& attrs) {
41+
op_info.creator_ = [](
42+
const std::string& type, const VariableNameMap& inputs,
43+
const VariableNameMap& outputs, const AttributeMap& attrs) {
5544
return new OpType(type, inputs, outputs, attrs);
5645
};
5746
op_info.grad_op_type_ = grad_op_type;
@@ -70,29 +59,21 @@ class OpRegistry {
7059
op_info.proto_ = nullptr;
7160
op_info.checker_ = nullptr;
7261
}
73-
op_info_map().insert(std::make_pair(op_type, op_info));
62+
OpInfoMap::Instance().Insert(op_type, op_info);
7463
// register gradient op
7564
if (!grad_op_type.empty()) {
7665
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
7766
}
7867
}
7968

8069
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
81-
const VarNameMap& inputs,
82-
const VarNameMap& outputs,
70+
const VariableNameMap& inputs,
71+
const VariableNameMap& outputs,
8372
AttributeMap attrs);
8473

8574
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
8675

87-
static VarNameMap ConvertOpDescVarsToVarNameMap(
88-
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);
89-
9076
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
91-
92-
static std::unordered_map<std::string, const OpInfo>& op_info_map() {
93-
static std::unordered_map<std::string, const OpInfo> op_info_map_;
94-
return op_info_map_;
95-
}
9677
};
9778

9879
class Registrar {

paddle/framework/operator.cc

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name,
115115
}
116116

117117
OperatorBase::OperatorBase(const std::string& type,
118-
const OperatorBase::VarNameMap& inputs,
119-
const OperatorBase::VarNameMap& outputs,
118+
const VariableNameMap& inputs,
119+
const VariableNameMap& outputs,
120120
const AttributeMap& attrs)
121121
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
122122
static std::atomic<size_t> gUniqId(0UL);
@@ -141,18 +141,10 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
141141
}
142142
return ret_val;
143143
}
144-
auto it = OpRegistry::op_info_map().find(type_);
145-
PADDLE_ENFORCE(
146-
it != OpRegistry::op_info_map().end(),
147-
"Operator %s not registered, cannot figure out intermediate outputs",
148-
type_);
149-
PADDLE_ENFORCE(
150-
it->second.proto_ != nullptr,
151-
"Operator %s has no OpProto, cannot figure out intermediate outputs",
152-
type_);
144+
auto& info = OpInfoMap::Instance().Get(Type());
153145

154146
// get all OpProto::Var for outputs
155-
for (auto& o : it->second.proto_->outputs()) {
147+
for (auto& o : info.Proto().outputs()) {
156148
// ignore all intermediate output
157149
if (o.intermediate()) continue;
158150
auto out = outputs_.find(o.name());

0 commit comments

Comments
 (0)