@@ -63,6 +63,17 @@ class ExecutionContext;
6363 */
6464class OperatorBase {
6565 public:
66+ OperatorBase () {} // TODO(yi): This constructor is to be removed.
67+ OperatorBase (const std::string& type, const std::vector<std::string>& inputs,
68+ const std::vector<std::string>& outputs,
69+ const AttributeMap& attrs,
70+ std::unordered_map<std::string, int >* in_out_idxs)
71+ : type_(type),
72+ inputs_ (inputs),
73+ outputs_(outputs),
74+ attrs_(attrs),
75+ in_out_idxs_(in_out_idxs) {}
76+
6677 virtual ~OperatorBase () {}
6778
6879 template <typename T>
@@ -109,6 +120,9 @@ class OperatorBase {
109120 const std::vector<std::string> Inputs () const { return inputs_; }
110121 const std::vector<std::string> Outputs () const { return outputs_; }
111122 const AttributeMap& Attrs () const { return attrs_; }
123+ const std::unordered_map<std::string, int >* InOutIdx () const {
124+ return in_out_idxs_.get ();
125+ }
112126
113127 public:
114128 std::string type_;
@@ -286,6 +300,14 @@ class OpKernel {
286300
287301class OperatorWithKernel : public OperatorBase {
288302 public:
303+ OperatorWithKernel () {} // TODO(yi): This constructor is to be removed.
304+ OperatorWithKernel (const std::string& type,
305+ const std::vector<std::string>& inputs,
306+ const std::vector<std::string>& outputs,
307+ const AttributeMap& attrs,
308+ std::unordered_map<std::string, int >* in_out_idxs)
309+ : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
310+
289311 struct OpKernelKey {
290312 platform::Place place_;
291313
@@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
335357 virtual void InferShape (const InferShapeContext& ctx) const = 0;
336358};
337359
360+ #define DEFINE_OPERATOR_CTOR (Class, ParentClass ) \
361+ public: \
362+ Class () { /* TODO(yi): This constructor is to be removed. */ \
363+ } \
364+ Class (const std::string& type, const std::vector<std::string>& inputs, \
365+ const std::vector<std::string>& outputs, \
366+ const ::paddle::framework::AttributeMap& attrs, \
367+ std::unordered_map<std::string, int >* in_out_idxs) \
368+ : ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
369+
338370} // namespace framework
339371} // namespace paddle
0 commit comments