Skip to content

Commit 19ab1dc

Browse files
authored
Merge pull request #3373 from Canpio/refactor_registry_macro
Refactorize registry macro
2 parents 5d4126a + 580445a commit 19ab1dc

File tree

3 files changed

+126
-79
lines changed

3 files changed

+126
-79
lines changed

paddle/framework/op_registry.h

Lines changed: 122 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -307,22 +307,45 @@ class OpRegistry {
307307
}
308308
};
309309

310+
class Registrar {
311+
public:
312+
// In our design, various kinds of classes, e.g., operators and kernels, have
313+
// their corresponding registry and registrar. The action of registration is
314+
// in the constructor of a global registrar variable, which, however, are not
315+
// used in the code that calls package framework, and would be removed from
316+
// the generated binary file by the linker. To avoid such removal, we add
317+
// Touch to all registrar classes and make USE_OP macros to call this
318+
// method. So, as long as the callee code calls USE_OP, the global
319+
// registrar variable won't be removed by the linker.
320+
void Touch() {}
321+
};
322+
310323
template <typename OpType, typename ProtoMakerType>
311-
class OpRegisterHelper {
324+
class OpRegistrar : public Registrar {
312325
public:
313-
explicit OpRegisterHelper(const char* op_type) {
326+
explicit OpRegistrar(const char* op_type) {
314327
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
315328
}
316329
};
317330

318331
template <typename GradOpType>
319-
class GradOpRegisterHelper {
332+
class GradOpRegistrar : public Registrar {
320333
public:
321-
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
334+
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
322335
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
323336
}
324337
};
325338

339+
template <typename PlaceType, typename KernelType>
340+
class OpKernelRegistrar : public Registrar {
341+
public:
342+
explicit OpKernelRegistrar(const char* op_type) {
343+
OperatorWithKernel::OpKernelKey key;
344+
key.place_ = PlaceType();
345+
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
346+
}
347+
};
348+
326349
/**
327350
* check if MACRO is used in GLOBAL NAMESPACE.
328351
*/
@@ -333,97 +356,121 @@ class GradOpRegisterHelper {
333356
msg)
334357

335358
/**
336-
* Macro to Register Operator.
359+
* Macro to register Operator.
337360
*/
338-
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
339-
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
340-
"REGISTER_OP must be in global namespace"); \
341-
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
342-
__op_register_##__op_type##__(#__op_type); \
343-
int __op_register_##__op_type##_handle__() { return 0; }
361+
#define REGISTER_OP(op_type, op_class, op_maker_class) \
362+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
363+
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
364+
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
365+
__op_registrar_##op_type##__(#op_type); \
366+
int TouchOpRegistrar_##op_type() { \
367+
__op_registrar_##op_type##__.Touch(); \
368+
return 0; \
369+
}
344370

345371
/**
346-
* Macro to Register Gradient Operator.
372+
* Macro to register Gradient Operator.
347373
*/
348-
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
349-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
350-
__reg_gradient_op__##__op_type##__grad_op_type, \
351-
"REGISTER_GRADIENT_OP must be in global namespace"); \
352-
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
353-
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
354-
#__grad_op_type); \
355-
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
356-
return 0; \
374+
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
375+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
376+
__reg_gradient_op__##op_type##_##grad_op_type, \
377+
"REGISTER_GRADIENT_OP must be called in global namespace"); \
378+
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
379+
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
380+
#grad_op_type); \
381+
int TouchOpGradientRegistrar_##op_type() { \
382+
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
383+
return 0; \
357384
}
358385

359386
/**
360-
* Macro to Forbid user register Gradient Operator.
387+
* Macro to register OperatorKernel.
361388
*/
362-
#define NO_GRADIENT(__op_type) \
363-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
364-
__reg_gradient_op__##__op_type##__op_type##_grad, \
365-
"NO_GRADIENT must be in global namespace")
389+
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
390+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
391+
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
392+
"REGISTER_OP_KERNEL must be called in global namespace"); \
393+
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
394+
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
395+
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
396+
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
397+
return 0; \
398+
}
366399

367400
/**
368-
* Macro to Register OperatorKernel.
401+
* Macro to Forbid user register Gradient Operator.
369402
*/
370-
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
371-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
372-
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
373-
"REGISTER_OP_KERNEL must be in global namespace"); \
374-
struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
375-
__op_kernel_register__##type##__##DEVICE_TYPE##__() { \
376-
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
377-
key.place_ = PlaceType(); \
378-
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
379-
.reset(new __VA_ARGS__()); \
380-
} \
381-
}; \
382-
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
383-
__reg_kernel_##type##__##DEVICE_TYPE##__; \
384-
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
385-
386-
// (type, KernelType)
387-
#define REGISTER_OP_GPU_KERNEL(type, ...) \
388-
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
389-
390-
// (type, KernelType)
391-
#define REGISTER_OP_CPU_KERNEL(type, ...) \
392-
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
403+
#define NO_GRADIENT(op_type) \
404+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
405+
__reg_gradient_op__##op_type##_##op_type##_grad, \
406+
"NO_GRADIENT must be called in global namespace")
407+
408+
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
409+
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
410+
411+
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
412+
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
393413

394414
/**
395415
* Macro to mark what Operator and Kernel we will use and tell the compiler to
396416
* link them into target.
397417
*/
398-
#define USE_OP_WITHOUT_KERNEL(op_type) \
399-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
400-
__use_op_without_kernel_##op_type, \
401-
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
402-
extern int __op_register_##op_type##_handle__(); \
403-
static int __use_op_ptr_##op_type##_without_kernel__ \
404-
__attribute__((unused)) = __op_register_##op_type##_handle__()
405-
406-
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
407-
STATIC_ASSERT_GLOBAL_NAMESPACE( \
408-
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
409-
"USE_OP_KERNEL must be in global namespace"); \
410-
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
411-
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
412-
__attribute__((unused)) = \
413-
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
414-
415-
// use Operator with only cpu kernel.
416-
#define USE_OP_CPU(op_type) \
417-
USE_OP_WITHOUT_KERNEL(op_type); \
418-
USE_OP_KERNEL(op_type, CPU)
418+
#define USE_OP_ITSELF(op_type) \
419+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
420+
__use_op_itself_##op_type, \
421+
"USE_OP_ITSELF must be called in global namespace"); \
422+
extern int TouchOpRegistrar_##op_type(); \
423+
static int use_op_itself_##op_type##_ __attribute__((unused)) = \
424+
TouchOpRegistrar_##op_type()
425+
426+
// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
427+
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
428+
// be compiled. `NO_GRAD` should be removed after all gradient ops are
429+
// compeleted.
430+
#define NO_GRAD
431+
#ifndef NO_GRAD
432+
#define USE_OP_GRADIENT(op_type) \
433+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
434+
__use_op_gradient_##op_type, \
435+
"USE_OP_GRADIENT must be called in global namespace"); \
436+
extern int TouchOpGradientRegistrar_##op_type(); \
437+
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
438+
TouchOpGradientRegistrar_##op_type()
439+
#else
440+
#define USE_OP_GRADIENT(op_type)
441+
#endif
442+
443+
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
444+
STATIC_ASSERT_GLOBAL_NAMESPACE( \
445+
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
446+
"USE_OP_DEVICE_KERNEL must be in global namespace"); \
447+
extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \
448+
static int use_op_kernel_##op_type##_##DEVICE_TYPE##_ \
449+
__attribute__((unused)) = \
450+
TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
451+
452+
// TODO(fengjiayi): The following macros seems ugly, do we have better method?
419453

420454
#ifdef PADDLE_ONLY_CPU
421-
#define USE_OP(op_type) USE_OP_CPU(op_type)
455+
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
422456
#else
423-
#define USE_OP(op_type) \
424-
USE_OP_CPU(op_type); \
425-
USE_OP_KERNEL(op_type, GPU)
457+
#define USE_OP_KERNEL(op_type) \
458+
USE_OP_DEVICE_KERNEL(op_type, CPU); \
459+
USE_OP_DEVICE_KERNEL(op_type, GPU)
426460
#endif
427461

462+
#define USE_NO_GRAD_OP(op_type) \
463+
USE_OP_ITSELF(op_type); \
464+
USE_OP_KERNEL(op_type)
465+
466+
#define USE_CPU_OP(op_type) \
467+
USE_OP_ITSELF(op_type); \
468+
USE_OP_DEVICE_KERNEL(op_type, CPU); \
469+
USE_OP_GRADIENT(op_type)
470+
471+
#define USE_OP(op_type) \
472+
USE_NO_GRAD_OP(op_type); \
473+
USE_OP_GRADIENT(op_type)
474+
428475
} // namespace framework
429476
} // namespace paddle

paddle/framework/pybind.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ limitations under the License. */
3030
namespace py = pybind11;
3131

3232
USE_OP(add_two);
33-
USE_OP_CPU(onehot_cross_entropy);
34-
USE_OP(sgd);
33+
USE_CPU_OP(onehot_cross_entropy);
34+
USE_NO_GRAD_OP(sgd);
3535
USE_OP(mul);
3636
USE_OP(mean);
3737
USE_OP(sigmoid);
3838
USE_OP(softmax);
3939
USE_OP(rowwise_add);
4040
USE_OP(fill_zeros_like);
41-
USE_OP_WITHOUT_KERNEL(recurrent_op);
41+
USE_OP_ITSELF(recurrent_op);
4242
USE_OP(gaussian_random);
4343
USE_OP(uniform_random);
4444

paddle/operators/recurrent_op_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,4 +395,4 @@ TEST(RecurrentOp, LinkMemories) {
395395

396396
USE_OP(add_two);
397397
USE_OP(mul);
398-
USE_OP_WITHOUT_KERNEL(recurrent_op);
398+
USE_OP_ITSELF(recurrent_op);

0 commit comments

Comments
 (0)