@@ -20,7 +20,7 @@ limitations under the License. */
2020#include < unordered_map>
2121#include < unordered_set>
2222#include " paddle/framework/attr_checker.h"
23- #include " paddle/framework/grad_op_creator .h"
23+ #include " paddle/framework/grad_op_builder .h"
2424#include " paddle/framework/op_desc.pb.h"
2525#include " paddle/framework/scope.h"
2626
@@ -222,7 +222,7 @@ class OpRegistry {
222222 public:
223223 template <typename OpType, typename ProtoMakerType>
224224 static void RegisterOp (const std::string& op_type) {
225- creators ()[op_type] = [] { return new OpType; };
225+ op_creators ()[op_type] = [] { return new OpType; };
226226 OpAttrChecker& op_checker = op_checkers ()[op_type];
227227 OpProto& op_proto = protos ()[op_type];
228228 auto maker = ProtoMakerType (&op_proto, &op_checker);
@@ -245,17 +245,19 @@ class OpRegistry {
245245 }
246246 }
247247
248- template <typename OpType>
249- static void RegisterGradOp (const std::string& op_type) {
250- grad_creators ()[op_type] = [] { return new OpType; };
248+ template <typename GradOpType>
249+ static void RegisterGradOp (const std::string& op_type,
250+ const std::string& grad_op_type) {
251+ op_creators ()[grad_op_type] = [] { return new GradOpType; };
252+ grad_ops ()[op_type] = grad_op_type;
251253 }
252254
253255 static std::shared_ptr<OperatorBase> CreateOp (const std::string& type,
254256 const VarNameList& inputs,
255257 const VarNameList& outputs,
256258 const AttributeMap& attrs) {
257- auto op_create_it = creators ().find (type);
258- PADDLE_ENFORCE (op_create_it != creators ().end (),
259+ auto op_create_it = op_creators ().find (type);
260+ PADDLE_ENFORCE (op_create_it != op_creators ().end (),
259261 " Operator %s cannot be found." , type);
260262
261263 auto op = op_create_it->second ();
@@ -300,8 +302,8 @@ class OpRegistry {
300302
301303 static std::shared_ptr<OperatorBase> CreateGradOp (
302304 std::shared_ptr<OperatorBase> op) {
303- GradOpCreator creator (op.get ());
304- std::shared_ptr<OperatorBase> grad_op (creator. Create ());
305+ GradOpBuilder builder (op.get ());
306+ std::shared_ptr<OperatorBase> grad_op (builder. Build ());
305307 grad_op->Init ();
306308 return grad_op;
307309 }
@@ -311,9 +313,9 @@ class OpRegistry {
311313 return protos_;
312314 };
313315
314- static std::unordered_map<std::string, OpCreator >& grad_creators () {
315- static std::unordered_map<std::string, OpCreator> grad_creators_ ;
316- return grad_creators_ ;
316+ static std::unordered_map<std::string, std::string >& grad_ops () {
317+ static std::unordered_map<std::string, std::string> grad_ops_ ;
318+ return grad_ops_ ;
317319 }
318320
319321 static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -322,12 +324,12 @@ class OpRegistry {
322324 return maps_;
323325 }
324326
325- private:
326- static std::unordered_map<std::string, OpCreator>& creators () {
327- static std::unordered_map<std::string, OpCreator> creators_;
328- return creators_;
327+ static std::unordered_map<std::string, OpCreator>& op_creators () {
328+ static std::unordered_map<std::string, OpCreator> op_creators_;
329+ return op_creators_;
329330 }
330331
332+ private:
331333 static std::unordered_map<std::string, OpAttrChecker>& op_checkers () {
332334 static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
333335 return op_checkers_;
@@ -353,11 +355,11 @@ class OpRegisterHelper {
353355 }
354356};
355357
356- template <typename OpType >
358+ template <typename GradOpType >
357359class GradOpRegisterHelper {
358360 public:
359- GradOpRegisterHelper (const char * op_type) {
360- OpRegistry::RegisterGradOp<OpType >(op_type);
361+ GradOpRegisterHelper (const char * op_type, const char * grad_op_type ) {
362+ OpRegistry::RegisterGradOp<GradOpType >(op_type, grad_op_type );
361363 }
362364};
363365
@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
383385/* *
384386 * Macro to Register Gradient Operator.
385387 */
386- #define REGISTER_GRADIENT_OP (__op_type, __op_class ) \
387- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
388- __reg_gradient_op__##__op_type, \
389- " REGISTER_GRADIENT_OP must be in global namespace" ); \
390- static ::paddle::framework::GradOpRegisterHelper<__op_class> \
391- __op_gradient_register_##__op_type##__(#__op_type); \
392- int __op_gradient_register_##__op_type##_handle__() { return 0 ; }
388+ #define REGISTER_GRADIENT_OP (__op_type, __grad_op_type, __grad_op_class ) \
389+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
390+ __reg_gradient_op__##__op_type##__grad_op_type, \
391+ " REGISTER_GRADIENT_OP must be in global namespace" ); \
392+ static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
393+ __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
394+ #__grad_op_type); \
395+ int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
396+ return 0 ; \
397+ }
393398
394399/* *
395400 * Macro to Register OperatorKernel.
0 commit comments