@@ -307,22 +307,45 @@ class OpRegistry {
307307 }
308308};
309309
310+ class Registrar {
311+ public:
312+ // In our design, various kinds of classes, e.g., operators and kernels, have
313+ // their corresponding registry and registrar. The action of registration is
314+ // in the constructor of a global registrar variable, which, however, are not
315+ // used in the code that calls package framework, and would be removed from
316+ // the generated binary file by the linker. To avoid such removal, we add
317+ // Touch to all registrar classes and make USE_OP macros to call this
318+ // method. So, as long as the callee code calls USE_OP, the global
319+ // registrar variable won't be removed by the linker.
320+ void Touch () {}
321+ };
322+
310323template <typename OpType, typename ProtoMakerType>
311- class OpRegisterHelper {
324+ class OpRegistrar : public Registrar {
312325 public:
313- explicit OpRegisterHelper (const char * op_type) {
326+ explicit OpRegistrar (const char * op_type) {
314327 OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
315328 }
316329};
317330
318331template <typename GradOpType>
319- class GradOpRegisterHelper {
332+ class GradOpRegistrar : public Registrar {
320333 public:
321- GradOpRegisterHelper (const char * op_type, const char * grad_op_type) {
334+ GradOpRegistrar (const char * op_type, const char * grad_op_type) {
322335 OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
323336 }
324337};
325338
339+ template <typename PlaceType, typename KernelType>
340+ class OpKernelRegistrar : public Registrar {
341+ public:
342+ explicit OpKernelRegistrar (const char * op_type) {
343+ OperatorWithKernel::OpKernelKey key;
344+ key.place_ = PlaceType ();
345+ OperatorWithKernel::AllOpKernels ()[op_type][key].reset (new KernelType);
346+ }
347+ };
348+
326349/* *
327350 * check if MACRO is used in GLOBAL NAMESPACE.
328351 */
@@ -333,97 +356,121 @@ class GradOpRegisterHelper {
333356 msg)
334357
335358/* *
336- * Macro to Register Operator.
359+ * Macro to register Operator.
337360 */
338- #define REGISTER_OP (__op_type, __op_class, __op_maker_class ) \
339- STATIC_ASSERT_GLOBAL_NAMESPACE (__reg_op__##__op_type, \
340- " REGISTER_OP must be in global namespace" ); \
341- static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
342- __op_register_##__op_type##__(#__op_type); \
343- int __op_register_##__op_type##_handle__() { return 0 ; }
361+ #define REGISTER_OP (op_type, op_class, op_maker_class ) \
362+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
363+ __reg_op__##op_type, " REGISTER_OP must be called in global namespace" ); \
364+ static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
365+ __op_registrar_##op_type##__(#op_type); \
366+ int TouchOpRegistrar_##op_type() { \
367+ __op_registrar_##op_type##__.Touch (); \
368+ return 0 ; \
369+ }
344370
345371/* *
346- * Macro to Register Gradient Operator.
372+ * Macro to register Gradient Operator.
347373 */
348- #define REGISTER_GRADIENT_OP (__op_type, __grad_op_type, __grad_op_class ) \
349- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
350- __reg_gradient_op__##__op_type##__grad_op_type, \
351- " REGISTER_GRADIENT_OP must be in global namespace" ); \
352- static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
353- __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
354- #__grad_op_type); \
355- int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
356- return 0 ; \
374+ #define REGISTER_GRADIENT_OP (op_type, grad_op_type, grad_op_class ) \
375+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
376+ __reg_gradient_op__##op_type##_##grad_op_type, \
377+ " REGISTER_GRADIENT_OP must be called in global namespace" ); \
378+ static ::paddle::framework::GradOpRegistrar<grad_op_class> \
379+ __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
380+ #grad_op_type); \
381+ int TouchOpGradientRegistrar_##op_type() { \
382+ __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch (); \
383+ return 0 ; \
357384 }
358385
359386/* *
360- * Macro to Forbid user register Gradient Operator .
387+ * Macro to register OperatorKernel .
361388 */
362- #define NO_GRADIENT (__op_type ) \
363- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
364- __reg_gradient_op__##__op_type##__op_type##_grad, \
365- " NO_GRADIENT must be in global namespace" )
389+ #define REGISTER_OP_KERNEL (op_type, DEVICE_TYPE, place_class, ...) \
390+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
391+ __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
392+ " REGISTER_OP_KERNEL must be called in global namespace" ); \
393+ static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
394+ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
395+ int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
396+ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch (); \
397+ return 0 ; \
398+ }
366399
367400/* *
368- * Macro to Register OperatorKernel .
401+ * Macro to Forbid user register Gradient Operator .
369402 */
370- #define REGISTER_OP_KERNEL (type, DEVICE_TYPE, PlaceType, ...) \
371- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
372- __reg_op_kernel_##type##_##DEVICE_TYPE##__, \
373- " REGISTER_OP_KERNEL must be in global namespace" ); \
374- struct __op_kernel_register__ ##type##__##DEVICE_TYPE##__ { \
375- __op_kernel_register__##type##__##DEVICE_TYPE##__() { \
376- ::paddle::framework::OperatorWithKernel::OpKernelKey key; \
377- key.place_ = PlaceType (); \
378- ::paddle::framework::OperatorWithKernel::AllOpKernels ()[#type][key] \
379- .reset (new __VA_ARGS__ ()); \
380- } \
381- }; \
382- static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
383- __reg_kernel_##type##__##DEVICE_TYPE##__; \
384- int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0 ; }
385-
386- // (type, KernelType)
387- #define REGISTER_OP_GPU_KERNEL (type, ...) \
388- REGISTER_OP_KERNEL (type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
389-
390- // (type, KernelType)
391- #define REGISTER_OP_CPU_KERNEL (type, ...) \
392- REGISTER_OP_KERNEL (type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
403+ #define NO_GRADIENT (op_type ) \
404+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
405+ __reg_gradient_op__##op_type##_##op_type##_grad, \
406+ " NO_GRADIENT must be called in global namespace" )
407+
408+ #define REGISTER_OP_GPU_KERNEL (op_type, ...) \
409+ REGISTER_OP_KERNEL (op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
410+
411+ #define REGISTER_OP_CPU_KERNEL (op_type, ...) \
412+ REGISTER_OP_KERNEL (op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
393413
394414/* *
395415 * Macro to mark what Operator and Kernel we will use and tell the compiler to
396416 * link them into target.
397417 */
398- #define USE_OP_WITHOUT_KERNEL (op_type ) \
399- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
400- __use_op_without_kernel_##op_type, \
401- " USE_OP_WITHOUT_KERNEL must be in global namespace" ); \
402- extern int __op_register_##op_type##_handle__(); \
403- static int __use_op_ptr_##op_type##_without_kernel__ \
404- __attribute__ ((unused)) = __op_register_##op_type##_handle__()
405-
406- #define USE_OP_KERNEL (op_type, DEVICE_TYPE ) \
407- STATIC_ASSERT_GLOBAL_NAMESPACE ( \
408- __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
409- " USE_OP_KERNEL must be in global namespace" ); \
410- extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
411- static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
412- __attribute__ ((unused)) = \
413- __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
414-
415- // use Operator with only cpu kernel.
416- #define USE_OP_CPU (op_type ) \
417- USE_OP_WITHOUT_KERNEL (op_type); \
418- USE_OP_KERNEL (op_type, CPU)
418+ #define USE_OP_ITSELF (op_type ) \
419+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
420+ __use_op_itself_##op_type, \
421+ " USE_OP_ITSELF must be called in global namespace" ); \
422+ extern int TouchOpRegistrar_##op_type(); \
423+ static int use_op_itself_##op_type##_ __attribute__ ((unused)) = \
424+ TouchOpRegistrar_##op_type()
425+
426+ // TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
427+ // `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
428+ // be compiled. `NO_GRAD` should be removed after all gradient ops are
429+ // compeleted.
430+ #define NO_GRAD
431+ #ifndef NO_GRAD
432+ #define USE_OP_GRADIENT (op_type ) \
433+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
434+ __use_op_gradient_##op_type, \
435+ " USE_OP_GRADIENT must be called in global namespace" ); \
436+ extern int TouchOpGradientRegistrar_##op_type(); \
437+ static int use_op_gradient_##op_type##_ __attribute__ ((unused)) = \
438+ TouchOpGradientRegistrar_##op_type()
439+ #else
440+ #define USE_OP_GRADIENT (op_type )
441+ #endif
442+
443+ #define USE_OP_DEVICE_KERNEL (op_type, DEVICE_TYPE ) \
444+ STATIC_ASSERT_GLOBAL_NAMESPACE ( \
445+ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
446+ " USE_OP_DEVICE_KERNEL must be in global namespace" ); \
447+ extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \
448+ static int use_op_kernel_##op_type##_##DEVICE_TYPE##_ \
449+ __attribute__ ((unused)) = \
450+ TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
451+
452+ // TODO(fengjiayi): The following macros seems ugly, do we have better method?
419453
420454#ifdef PADDLE_ONLY_CPU
421- #define USE_OP (op_type ) USE_OP_CPU (op_type)
455+ #define USE_OP_KERNEL (op_type ) USE_OP_DEVICE_KERNEL (op_type, CPU )
422456#else
423- #define USE_OP (op_type ) \
424- USE_OP_CPU (op_type); \
425- USE_OP_KERNEL (op_type, GPU)
457+ #define USE_OP_KERNEL (op_type ) \
458+ USE_OP_DEVICE_KERNEL (op_type, CPU); \
459+ USE_OP_DEVICE_KERNEL (op_type, GPU)
426460#endif
427461
462+ #define USE_NO_GRAD_OP (op_type ) \
463+ USE_OP_ITSELF (op_type); \
464+ USE_OP_KERNEL (op_type)
465+
466+ #define USE_CPU_OP (op_type ) \
467+ USE_OP_ITSELF (op_type); \
468+ USE_OP_DEVICE_KERNEL (op_type, CPU); \
469+ USE_OP_GRADIENT (op_type)
470+
471+ #define USE_OP (op_type ) \
472+ USE_NO_GRAD_OP (op_type); \
473+ USE_OP_GRADIENT (op_type)
474+
428475} // namespace framework
429476} // namespace paddle
0 commit comments