Skip to content

Commit 72b5bd9

Browse files
authored
Merge pull request #3036 from Canpio/dev_update_backward
update gradient operator registry mechanism
2 parents 91689b6 + e8a0e92 commit 72b5bd9

File tree

10 files changed

+55
-49
lines changed

10 files changed

+55
-49
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
1919
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
2020
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
2121

22-
cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
23-
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
22+
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
23+
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
2424
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
25-
cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
25+
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
2626

2727
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
2828
# Generate an empty __init__.py to make framework_py_proto as a valid python module.

paddle/framework/grad_op_creator.cc renamed to paddle/framework/grad_op_builder.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/framework/grad_op_creator.h"
15+
#include "paddle/framework/grad_op_builder.h"
1616
#include "paddle/framework/op_registry.h"
1717

1818
namespace paddle {
1919
namespace framework {
2020

21-
OperatorBase* GradOpCreator::Create() {
21+
OperatorBase* GradOpBuilder::Build() {
2222
BuildOpInOutArgList();
23-
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
23+
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
24+
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
25+
grad_op->type_ = grad_op_type;
2426
CompleteGradOp(grad_op);
2527
return grad_op;
2628
}
2729

28-
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
30+
OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
2931
const VarIndexMap& var_map,
3032
const std::vector<int>& format,
3133
InOutType type) {
@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
3638
end_idx);
3739
}
3840

39-
void GradOpCreator::BuildOpInOutArgList() {
41+
void GradOpBuilder::BuildOpInOutArgList() {
4042
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
4143
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
4244
const std::vector<int>& in_format =
@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
5759
}
5860
}
5961

60-
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
62+
void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
6163
std::vector<std::string>& in_out,
6264
std::vector<int>& format,
6365
VarIndexMap* varmap, int& idx,
@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
8082
format.push_back(in_out.size());
8183
}
8284

83-
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
84-
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
85+
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
8586
grad_op->attrs_ = op_->attrs_;
8687
grad_op->attrs_.erase("input_format");
8788
grad_op->attrs_.erase("output_format");

paddle/framework/grad_op_creator.h renamed to paddle/framework/grad_op_builder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ struct OpInOutArg {
2525
size_t end_idx_;
2626
};
2727

28-
class GradOpCreator {
28+
class GradOpBuilder {
2929
using VarIndexMap = std::unordered_map<std::string, int>;
3030

3131
public:
32-
GradOpCreator(const OperatorBase* op) : op_(op) {}
33-
OperatorBase* Create();
32+
GradOpBuilder(const OperatorBase* op) : op_(op) {}
33+
OperatorBase* Build();
3434

3535
private:
3636
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,

paddle/framework/grad_op_creator_test.cc renamed to paddle/framework/grad_op_builder_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "paddle/framework/grad_op_creator.h"
1+
#include "paddle/framework/grad_op_builder.h"
22
#include <gtest/gtest.h>
33
#include "paddle/framework/op_registry.h"
44
#include "paddle/framework/operator.h"
@@ -8,7 +8,7 @@ USE_OP(add_two);
88
namespace paddle {
99
namespace framework {
1010

11-
TEST(GradOpCreator, AddTwo) {
11+
TEST(GradOpBuilder, AddTwo) {
1212
std::shared_ptr<OperatorBase> add_op(
1313
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
1414
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);

paddle/framework/op_registry.h

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License. */
2020
#include <unordered_map>
2121
#include <unordered_set>
2222
#include "paddle/framework/attr_checker.h"
23-
#include "paddle/framework/grad_op_creator.h"
23+
#include "paddle/framework/grad_op_builder.h"
2424
#include "paddle/framework/op_desc.pb.h"
2525
#include "paddle/framework/scope.h"
2626

@@ -222,7 +222,7 @@ class OpRegistry {
222222
public:
223223
template <typename OpType, typename ProtoMakerType>
224224
static void RegisterOp(const std::string& op_type) {
225-
creators()[op_type] = [] { return new OpType; };
225+
op_creators()[op_type] = [] { return new OpType; };
226226
OpAttrChecker& op_checker = op_checkers()[op_type];
227227
OpProto& op_proto = protos()[op_type];
228228
auto maker = ProtoMakerType(&op_proto, &op_checker);
@@ -245,17 +245,19 @@ class OpRegistry {
245245
}
246246
}
247247

248-
template <typename OpType>
249-
static void RegisterGradOp(const std::string& op_type) {
250-
grad_creators()[op_type] = [] { return new OpType; };
248+
template <typename GradOpType>
249+
static void RegisterGradOp(const std::string& op_type,
250+
const std::string& grad_op_type) {
251+
op_creators()[grad_op_type] = [] { return new GradOpType; };
252+
grad_ops()[op_type] = grad_op_type;
251253
}
252254

253255
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
254256
const VarNameList& inputs,
255257
const VarNameList& outputs,
256258
const AttributeMap& attrs) {
257-
auto op_create_it = creators().find(type);
258-
PADDLE_ENFORCE(op_create_it != creators().end(),
259+
auto op_create_it = op_creators().find(type);
260+
PADDLE_ENFORCE(op_create_it != op_creators().end(),
259261
"Operator %s cannot be found.", type);
260262

261263
auto op = op_create_it->second();
@@ -300,8 +302,8 @@ class OpRegistry {
300302

301303
static std::shared_ptr<OperatorBase> CreateGradOp(
302304
std::shared_ptr<OperatorBase> op) {
303-
GradOpCreator creator(op.get());
304-
std::shared_ptr<OperatorBase> grad_op(creator.Create());
305+
GradOpBuilder builder(op.get());
306+
std::shared_ptr<OperatorBase> grad_op(builder.Build());
305307
grad_op->Init();
306308
return grad_op;
307309
}
@@ -311,9 +313,9 @@ class OpRegistry {
311313
return protos_;
312314
};
313315

314-
static std::unordered_map<std::string, OpCreator>& grad_creators() {
315-
static std::unordered_map<std::string, OpCreator> grad_creators_;
316-
return grad_creators_;
316+
static std::unordered_map<std::string, std::string>& grad_ops() {
317+
static std::unordered_map<std::string, std::string> grad_ops_;
318+
return grad_ops_;
317319
}
318320

319321
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -322,12 +324,12 @@ class OpRegistry {
322324
return maps_;
323325
}
324326

325-
private:
326-
static std::unordered_map<std::string, OpCreator>& creators() {
327-
static std::unordered_map<std::string, OpCreator> creators_;
328-
return creators_;
327+
static std::unordered_map<std::string, OpCreator>& op_creators() {
328+
static std::unordered_map<std::string, OpCreator> op_creators_;
329+
return op_creators_;
329330
}
330331

332+
private:
331333
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
332334
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
333335
return op_checkers_;
@@ -353,11 +355,11 @@ class OpRegisterHelper {
353355
}
354356
};
355357

356-
template <typename OpType>
358+
template <typename GradOpType>
357359
class GradOpRegisterHelper {
358360
public:
359-
GradOpRegisterHelper(const char* op_type) {
360-
OpRegistry::RegisterGradOp<OpType>(op_type);
361+
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
362+
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
361363
}
362364
};
363365

@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
383385
/**
384386
* Macro to Register Gradient Operator.
385387
*/
386-
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
387-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
388-
__reg_gradient_op__##__op_type, \
389-
"REGISTER_GRADIENT_OP must be in global namespace"); \
390-
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
391-
__op_gradient_register_##__op_type##__(#__op_type); \
392-
int __op_gradient_register_##__op_type##_handle__() { return 0; }
388+
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
389+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
390+
__reg_gradient_op__##__op_type##__grad_op_type, \
391+
"REGISTER_GRADIENT_OP must be in global namespace"); \
392+
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
393+
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
394+
#__grad_op_type); \
395+
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
396+
return 0; \
397+
}
393398

394399
/**
395400
* Macro to Register OperatorKernel.

paddle/operators/add_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,6 @@ class AddOpGrad : public framework::OperatorWithKernel {
6565
} // namespace paddle
6666

6767
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
68-
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
68+
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
6969
REGISTER_OP_CPU_KERNEL(
7070
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);

paddle/operators/add_op_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
2222
auto& protos = paddle::framework::OpRegistry::protos();
2323
auto it = protos.find("add_two");
2424
ASSERT_NE(it, protos.end());
25-
auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
26-
auto it1 = grad_creators.find("add_two");
27-
ASSERT_NE(it1, grad_creators.end());
25+
auto& op_creators = paddle::framework::OpRegistry::op_creators();
26+
auto it1 = op_creators.find("add_two_grad");
27+
ASSERT_NE(it1, op_creators.end());
2828
}

paddle/operators/mul_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
6767
} // namespace paddle
6868

6969
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
70-
REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
70+
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
7171

7272
REGISTER_OP_CPU_KERNEL(
7373
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);

paddle/operators/sigmoid_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
5656
REGISTER_OP(sigmoid,
5757
paddle::operators::SigmoidOp,
5858
paddle::operators::SigmoidOpMaker);
59-
REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
59+
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
6060

6161
REGISTER_OP_CPU_KERNEL(
6262
sigmoid,

paddle/operators/softmax_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
5959
namespace ops = paddle::operators;
6060

6161
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
62-
REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
62+
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
6363
REGISTER_OP_CPU_KERNEL(softmax,
6464
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);

0 commit comments

Comments
 (0)