Skip to content

Commit 9e0c680

Browse files
committed
Python Generate OpCreation Methods by OpProto
All OpCreation methods are generated by the `create_op_creation_methods::__bootstrap__` method and stored as methods of the `op_creations` object. There are three parts to implementing this feature. 1. Get all registered `OpProto`s from the C++ side. This is implemented in the `get_all_op_protos` method. 2. Create a function that converts `kwargs` to an `OpDesc` based on each op's `OpProto`. This is the `OpDescCreationMethod` class. 3. Convert an `OpProto` to a `docstring` with the `get_docstring_from_op_proto` method. All three methods are unit tested. `__bootstrap__` just combines them and creates a method at runtime. For details, please refer to the docstrings in `create_op_creation_methods.py` and the unit test `test_op_creation_methods.py`.
1 parent 1faf5e0 commit 9e0c680

File tree

7 files changed

+539
-18
lines changed

7 files changed

+539
-18
lines changed

paddle/framework/op_registry.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <algorithm>
4+
#include <atomic>
45
#include <type_traits>
56
#include <unordered_map>
67
#include <unordered_set>
@@ -199,19 +200,31 @@ class OpRegistry {
199200
}
200201

201202
/// Creates and fully initializes an operator from its description.
///
/// The creator registered for `op_desc.type()` builds the operator; the
/// description, input names, output names and attributes are then copied
/// into it, the attributes are validated by the checker registered for the
/// same type, temporary outputs are renamed and `Init()` is called.
///
/// @param op_desc description of the operator to create.
/// @return the newly created operator.
/// @note Lookup uses `creators().at()` and `op_checkers().at()`, so an
///       unregistered type fails inside those `at()` calls.
static OperatorPtr CreateOp(const OpDesc& op_desc) {
  //! Create an OpPtr by type.
  const std::string& op_type = op_desc.type();
  OperatorPtr op(creators().at(op_type)());

  //! Fill the op's data members here. A constructor is not used because it
  //! would be noisy for Op developers.
  op->desc_ = op_desc;
  op->inputs_.reserve(static_cast<size_t>(op_desc.inputs_size()));
  std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
            std::back_inserter(op->inputs_));
  op->outputs_.reserve(static_cast<size_t>(op_desc.outputs_size()));
  std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
            std::back_inserter(op->outputs_));

  //! Fill attrs, and validate attrs.
  for (const auto& attr : op_desc.attrs()) {
    op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
  }
  op_checkers().at(op_type).Check(op->attrs_);

  //! Convert temporary variable names to unique variable names.
  AssignTempVariable(op.get());

  //! Custom Init hook for complex Ops. For a simple Op, the Init method
  //! does nothing.
  op->Init();
  return op;
}
@@ -222,6 +235,17 @@ class OpRegistry {
222235
};
223236

224237
private:
238+
/// Gives every temporary output of `op` a unique name.
///
/// Each output whose name equals `OperatorBase::TMP_VAR_NAME()` has the
/// operator type, an "@" and the next value of a counter appended to it.
/// The counter is a function-local static atomic shared by all calls.
///
/// @param op operator whose output names are rewritten in place.
static void AssignTempVariable(OperatorBase* op) {
  static std::atomic<size_t> gUniqId(0UL);
  // The placeholder is the same for every output, so build it only once
  // instead of constructing a new string on each iteration.
  const std::string tmp_name = OperatorBase::TMP_VAR_NAME();
  for (auto& outname : op->outputs_) {
    if (outname != tmp_name) {
      continue;
    }
    outname += op->Type();
    outname += "@";
    outname += std::to_string(gUniqId.fetch_add(1));
  }
}
248+
225249
static std::unordered_map<std::string, OpCreator>& creators() {
226250
static std::unordered_map<std::string, OpCreator> creators_;
227251
return creators_;

paddle/framework/operator.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,21 @@ namespace framework {
1919

2020
std::string OperatorBase::DebugString() const {
2121
std::stringstream ss;
22-
ss << "=================\n";
23-
ss << "type = " << desc_.type() << "\n";
24-
ss << "inputs = [";
25-
for (auto& ipt : inputs_) {
26-
ss << ipt << ", ";
22+
ss << "Op(" << Type() << "), inputs:(";
23+
for (size_t i = 0; i < inputs_.size(); ++i) {
24+
ss << inputs_[i];
25+
if (i != inputs_.size() - 1) {
26+
ss << ", ";
27+
}
2728
}
28-
ss << "]\n";
29-
ss << "outputs = [";
30-
for (auto& opt : outputs_) {
31-
ss << opt << ", ";
29+
ss << "), outputs:(";
30+
for (size_t i = 0; i < outputs_.size(); ++i) {
31+
ss << outputs_[i];
32+
if (i != outputs_.size() - 1) {
33+
ss << ", ";
34+
}
3235
}
33-
ss << "]\n";
34-
ss << "attr_keys = [";
35-
for (auto& attr : attrs_) {
36-
ss << attr.first << ", ";
37-
}
38-
ss << "]\n";
36+
ss << ").";
3937
return ss.str();
4038
}
4139

paddle/framework/operator.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
3939
*/
4040
class OperatorBase {
4141
public:
42+
/// If a variable is an empty variable, this name will be used.
43+
static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }
44+
45+
/// If a variable is a temporary variable, that name will be set in Python,
46+
/// but it will be converted to a unique name in scope after OpCreator.
47+
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
48+
4249
virtual ~OperatorBase() {}
4350

4451
template <typename T>
@@ -62,7 +69,6 @@ class OperatorBase {
6269
virtual void Run(const ScopePtr& scope,
6370
const platform::DeviceContext& dev_ctx) const = 0;
6471

65-
protected:
6672
/// Returns the operator type recorded in this operator's description
/// (`desc_`).
std::string Type() const { return desc_.type(); }
6773

6874
public:

paddle/pybind/pybind.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ All parameter, weight, gradient are variables in Paddle.
6363
}
6464
return ret_values;
6565
});
66+
m.def_submodule(
67+
"var_names",
68+
"The module will return special predefined variable name in Paddle")
69+
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
70+
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
71+
72+
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
73+
.def("__str__", &pd::OperatorBase::DebugString)
74+
.def_static("create", [](const std::string& protobin) {
75+
pd::OpDesc desc;
76+
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
77+
"Cannot parse user input to OpDesc");
78+
PADDLE_ENFORCE(desc.IsInitialized(),
79+
"User OpDesc is not initialized, reason %s",
80+
desc.InitializationErrorString());
81+
return pd::OpRegistry::CreateOp(desc);
82+
});
6683

6784
return m.ptr();
6885
}
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,246 @@
11
import paddle.v2.framework.core as core
22
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
3+
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
4+
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
5+
import cStringIO
36

47

58
def get_all_op_protos():
    """
    Fetch every op proto registered on the Paddle C++ side.

    :return: list of OpProto, one parsed from each serialized string
        returned by ``core.get_all_op_protos()``
    """
    return [
        op_proto_pb2.OpProto.FromString(str(serialized))
        for serialized in core.get_all_op_protos()
    ]
19+
20+
21+
class OpDescCreationMethod(object):
    """
    A functor object that converts user input (keyword arguments) to an
    OpDesc based on an OpProto.

    :param op_proto: The OpProto object.
    :type op_proto: op_proto_pb2.OpProto
    :raises TypeError: if ``op_proto`` is not an OpProto.
    """

    def __init__(self, op_proto):
        if not isinstance(op_proto, op_proto_pb2.OpProto):
            raise TypeError("Argument should be OpProto")
        self.__op_proto__ = op_proto

    def __call__(self, *args, **kwargs):
        """
        Convert user input to OpDesc. Only keyword args are supported.

        Attributes are appended in this order: the input format (if any),
        the output format (if any), the temporary index list (if not empty),
        then every non-generated attribute the user supplied.

        :return: OpDesc based on user input
        :rtype: op_desc_pb2.OpDesc
        :raises ValueError: if any positional argument is passed.
        :raises NotImplementedError: if a supplied attribute has a type
            that is not handled here.
        """
        if len(args) != 0:
            raise ValueError("Only keyword arguments is supported by Paddle")
        op_desc = op_desc_pb2.OpDesc()

        # Inputs
        ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output(
            "input", kwargs, self.__op_proto__.inputs)
        op_desc.inputs.extend(ipts)
        if ipt_format is not None:
            op_desc.attrs.extend([ipt_format])

        # Outputs
        outs, out_format, tmp_index = \
            OpDescCreationMethod.extract_input_or_output(
                "output", kwargs, self.__op_proto__.outputs)
        op_desc.outputs.extend(outs)
        if out_format is not None:
            op_desc.attrs.extend([out_format])
        if len(tmp_index) != 0:
            tmp_index_attr = op_desc.attrs.add()
            tmp_index_attr.type = attr_type_pb2.INTS
            tmp_index_attr.name = "temporary_index"
            tmp_index_attr.ints.extend(tmp_index)

        # Types
        op_desc.type = self.__op_proto__.type

        # Attrs
        for attr in self.__op_proto__.attrs:
            if attr.generated:
                # Generated attributes are not taken from user input.
                continue
            user_defined_attr = kwargs.get(attr.name, None)
            if user_defined_attr is None:
                continue
            new_attr = op_desc.attrs.add()
            new_attr.name = attr.name
            new_attr.type = attr.type
            if attr.type == attr_type_pb2.INT:
                new_attr.i = user_defined_attr
            elif attr.type == attr_type_pb2.FLOAT:
                new_attr.f = user_defined_attr
            elif attr.type == attr_type_pb2.STRING:
                new_attr.s = user_defined_attr
            elif attr.type == attr_type_pb2.INTS:
                new_attr.ints.extend(user_defined_attr)
            elif attr.type == attr_type_pb2.FLOATS:
                new_attr.floats.extend(user_defined_attr)
            elif attr.type == attr_type_pb2.STRINGS:
                new_attr.strings.extend(user_defined_attr)
            else:
                # attr.type is an enum number; it must be converted before
                # it can be concatenated to the message.
                raise NotImplementedError("Not support attribute type " +
                                          str(attr.type))

        return op_desc

    @staticmethod
    def extract_input_or_output(in_out, kwargs, meta):
        """
        Extract input variable names or output variable names from keyword
        arguments, based on VarProtos.

        :param in_out: "input" or "output"
        :param kwargs: keyword arguments that the user passed.
        :param meta: a list of VarProto
        :return: Three objects are returned. The variable names. The
        input_format or output_format attribute (None if no input or output
        is multiple). The temporary variable index list, holding the
        zero-based position of each temporary variable in the variable names.
        """
        multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta))
        tmp_index = []
        retv = []
        if multiple:
            var_format = op_desc_pb2.AttrDesc()
            var_format.type = attr_type_pb2.INTS
            var_format.name = "%s_format" % in_out
            var_format.ints.append(0)

            for var in meta:
                if var.temporary:
                    var_name = [core.var_names.temp()]
                    tmp_index.append(len(retv))
                else:
                    var_name = kwargs.get(var.name, [])
                    if not isinstance(var_name, list):
                        var_name = [var_name]
                retv.extend(var_name)
                # Each entry is the end offset of this variable's names.
                var_format.ints.append(len(var_name) + var_format.ints[-1])
            return retv, var_format, tmp_index
        else:
            for var in meta:
                if var.temporary:
                    # Record the position before appending so the index is
                    # zero-based, as in the multiple branch above.
                    tmp_index.append(len(retv))
                    if var.name in kwargs:
                        retv.append(kwargs[var.name])
                    else:
                        retv.append(core.var_names.temp())
                elif var.name in kwargs:
                    retv.append(kwargs[var.name])
                else:
                    retv.append(core.var_names.empty())
            return retv, None, tmp_index

    @staticmethod
    def any_is_true(generator):
        """
        Reduce a bool array to one. If any of them is True, then return True.
        """
        return any(generator)
147+
148+
149+
def get_docstring_from_op_proto(op_proto):
    """
    Generate a docstring from an OpProto.

    The result is the op comment followed by a ``:param``/``:type`` pair for
    every input, every output and every attribute, in that order.

    :param op_proto: an OpProto instance.
    :type op_proto: op_proto_pb2.OpProto
    :return: docstring
    :raises TypeError: if ``op_proto`` is not an OpProto.
    :raises RuntimeError: if an attribute has a type that is not handled here.
    """
    if not isinstance(op_proto, op_proto_pb2.OpProto):
        raise TypeError("Input must be OpProto")
    f = cStringIO.StringIO()
    f.write(op_proto.comment)
    f.write("\n")

    def __append_param__(name, comment, type_name):
        # Maybe replace the following line with template engine is better.
        f.write(":param ")
        f.write(name)
        f.write(": ")
        f.write(comment)
        f.write("\n")
        f.write(":type ")
        f.write(name)
        f.write(": ")
        f.write(type_name)
        f.write("\n")

    for ipt in op_proto.inputs:
        __append_param__(ipt.name, ipt.comment, "list | basestr"
                         if ipt.multiple else "basestr")

    temp_var_prefix = \
        "This is a temporary variable. It does not have to set by user. "
    for opt in op_proto.outputs:
        __append_param__(opt.name, opt.comment if not opt.temporary else
                         temp_var_prefix + opt.comment, "list | basestr"
                         if opt.multiple else "basestr")

    # Documented type name for each supported attribute type.
    attr_type_names = {
        attr_type_pb2.INT: "int",
        attr_type_pb2.FLOAT: "float",
        attr_type_pb2.STRING: "basestr",
        attr_type_pb2.INTS: "list of int",
        attr_type_pb2.FLOATS: "list of float",
        attr_type_pb2.STRINGS: "list of basestr",
    }
    for attr in op_proto.attrs:
        attr_type = attr_type_names.get(attr.type)
        if attr_type is None:
            # attr.type is an enum number; it must be converted before it
            # can be concatenated to the message.
            raise RuntimeError("Not supported attribute type " +
                               str(attr.type))

        __append_param__(attr.name, attr.comment, attr_type)

    return f.getvalue()
207+
208+
209+
def create_op_creation_method(op_proto):
    """
    Build the function that creates operators described by an OpProto.

    The returned function turns its keyword arguments into an OpDesc,
    serializes it and passes it to ``core.Operator.create``. Its docstring
    is generated from ``op_proto``.
    """
    build_desc = OpDescCreationMethod(op_proto)

    def __impl__(*args, **kwargs):
        serialized = build_desc(*args, **kwargs).SerializeToString()
        return core.Operator.create(serialized)

    __impl__.__doc__ = get_docstring_from_op_proto(op_proto)
    return __impl__
221+
222+
223+
class OpCreationsHolder(object):
    """
    An empty object used as a holder for all op creation methods.

    Use `op_creations.xxx_op` to access them.
    """
230+
231+
232+
# Module-level holder object; __bootstrap__ attaches one creation function
# per registered op to it.
op_creations = OpCreationsHolder()
233+
234+
235+
def __bootstrap__():
    """
    Bootstrap function for this module. For every registered op proto it
    builds a creation function at runtime, names it after the op type and
    attaches it to `op_creations` under that name.
    """
    for op_proto in get_all_op_protos():
        op_name = str(op_proto.type)
        creator = create_op_creation_method(op_proto)
        creator.__name__ = op_name
        setattr(op_creations, op_name, creator)
244+
245+
246+
# Populate op_creations as a side effect of importing this module.
__bootstrap__()

0 commit comments

Comments
 (0)