@@ -14,44 +14,22 @@ limitations under the License. */
 
 #pragma once
 
+#include <paddle/framework/attr_checker.h>
+#include <paddle/framework/op_desc.pb.h>
+#include <paddle/framework/scope.h>
+#include <paddle/platform/device_context.h>
+#include <paddle/platform/place.h>
+#include <paddle/utils/Error.h>
 #include <boost/variant.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
-#include "paddle/framework/attr_checker.h"
-#include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/scope.h"
-#include "paddle/utils/Error.h"
-
 namespace paddle {
 namespace framework {
 
 class OperatorBase;
 
-class DeviceContext {};
-
-/**
- * OpRunContext is the only parameter of Operator's Run function.
- * Run will get input/output variables, state such as momentum and
- * device resource such as CUDA stream, cublas handle, etc. from
- * OpRunContext. User should construct it before run the Operator.
- */
-class OpRunContext {
- public:
-  OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
-               const DeviceContext* device_context)
-      : op_(op), scope_(scope), device_context_(device_context) {}
-
-  const Variable* Input(int index) const;
-  Variable* Output(int index) const;
-
- public:
-  const OperatorBase* op_;
-  const std::shared_ptr<Scope> scope_;
-  const DeviceContext* device_context_;
-};
-
 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly. User
@@ -77,7 +55,10 @@ class OperatorBase {
 
   /// Net will call this function to Run an op.
   virtual void Run(const std::shared_ptr<Scope>& scope,
-                   const DeviceContext* dev_ctx) const = 0;
+                   const platform::DeviceContext& dev_ctx) const = 0;
+
+ protected:
+  std::string Type() const { return desc_.type(); }
 
  public:
   OpDesc desc_;
@@ -86,22 +67,84 @@ class OperatorBase {
   AttributeMap attrs_;
 };
 
+class OpKernel {
+ public:
+  /**
+   * KernelContext is the only parameter of OpKernel's Compute function.
+   * Compute will get input/output variables, state such as momentum, and
+   * device resources such as the CUDA stream and cublas handle from
+   * KernelContext. It should be constructed before the kernel is run.
+   */
+  class KernelContext {
+   public:
+    KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                  const platform::DeviceContext& device_context)
+        : op_(*op), scope_(scope), device_context_(device_context) {}
+
+    const Variable* Input(int index) const {
+      return scope_->GetVariable(op_.inputs_[index]);
+    }
+
+    Variable* Output(int index) const {
+      return scope_->GetVariable(op_.outputs_[index]);
+    }
+
+    const OperatorBase& op_;
+    const std::shared_ptr<Scope>& scope_;
+    const platform::DeviceContext& device_context_;
+  };
+
+  virtual void Compute(const KernelContext& context) const = 0;
+
+  virtual ~OpKernel() {}
+};
+
 class OperatorWithKernel : public OperatorBase {
  public:
-  virtual ~OperatorWithKernel() {}
+  struct OpKernelKey {
+    platform::Place place_;
 
-  virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
+    OpKernelKey() = default;
+    OpKernelKey(const platform::DeviceContext& dev_ctx) {
+      place_ = dev_ctx.GetPlace();
+    }
+
+    bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
+  };
+
+  struct OpKernelHash {
+    std::hash<bool> hash_;
+    size_t operator()(const OpKernelKey& key) const {
+      return hash_(platform::is_gpu_place(key.place_));
+    }
+  };
+
+  using OpKernelMap =
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
 
   void Run(const std::shared_ptr<Scope>& scope,
-           const DeviceContext* dev_ctx) const {
-    OpRunContext op_ctx(this, scope, dev_ctx);
-    Run(&op_ctx);
+           const platform::DeviceContext& dev_ctx) const final {
+    auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
+    opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
   }
 
-  /// when implement an Op, your should implement this function.
-  /// this function should be moved to OpKernel later
-  virtual void Run(const OpRunContext* context) const = 0;
+  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
+  AllOpKernels() {
+    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
+    return g_all_op_kernels;
+  };
 };
 
 }  // namespace framework
 }  // namespace paddle
+
+#define REGISTER_OP_KERNEL(type, PlaceType, KernelType)                    \
+  struct __op_kernel_register__##type##__ {                                \
+    __op_kernel_register__##type##__() {                                   \
+      ::paddle::framework::OperatorWithKernel::OpKernelKey key;            \
+      key.place_ = PlaceType();                                            \
+      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key]  \
+          .reset(new KernelType());                                        \
+    }                                                                      \
+  };                                                                       \
+  static __op_kernel_register__##type##__ __reg_kernel_##type##__
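
A minimal sketch of how a concrete kernel might use the OpKernel/KernelContext interface introduced above. The class CosKernel, the "cos" op type, and the include path are assumptions made for illustration, not part of this commit; a real kernel would also pull device resources from device_context_.

#include "paddle/framework/operator.h"  // assumed path of the header in this diff

namespace paddle {
namespace framework {

// Hypothetical kernel for an imaginary "cos" operator: reads input 0,
// writes output 0.  Purely illustrative.
class CosKernel : public OpKernel {
 public:
  void Compute(const KernelContext& context) const override {
    // KernelContext resolves variables by index through the owning
    // operator's inputs_/outputs_ name lists and the enclosing Scope.
    const Variable* in = context.Input(0);
    Variable* out = context.Output(0);
    // A real kernel would read a tensor from `in`, compute into `out`, and
    // use context.device_context_ for device resources (CUDA stream,
    // cublas handle, ...).  Omitted here.
    (void)in;
    (void)out;
  }
};

}  // namespace framework
}  // namespace paddle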
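Registration and dispatch, under the same assumptions: the CPU place type is spelled platform::CPUPlace below as a guess; substitute whatever paddle/platform/place.h actually defines.

// Hypothetical registration, typically at namespace scope in the kernel's
// .cc file.  #type stringizes to "cos", so the kernel is stored in
// AllOpKernels()["cos"] under an OpKernelKey whose place_ is CPUPlace().
REGISTER_OP_KERNEL(cos, ::paddle::platform::CPUPlace,
                   ::paddle::framework::CosKernel);

// At run time a net calls op->Run(scope, dev_ctx).  For an
// OperatorWithKernel that is exactly:
//
//   auto& kernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
//   kernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
//
// so the kernel is selected purely by the place carried in the device
// context (CPU vs GPU, per OpKernelHash), independent of the op's attributes.

Note that AllOpKernels() hands out the registry by reference, so the registrar's operator[] insertions and Run's .at() lookups operate on the same function-local static map.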