Skip to content

Commit 697facc

Browse files
authored
"add registry interface" (#6449)
* "add registry interface" * "move function to registry" * "rename with meaningful name" * "add exposed layers" * "fixed based on comments" * "remove unsed comments"
1 parent 8ad36cd commit 697facc

File tree

3 files changed

+221
-172
lines changed

3 files changed

+221
-172
lines changed

python/paddle/v2/fluid/layers.py

Lines changed: 13 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
1-
import core
1+
import contextlib
2+
23
import proto.framework_pb2 as framework_pb2
4+
import core
35
from framework import OpProtoHolder, Variable, Program, Operator
46
from initializer import Constant, Normal, Xavier, Initializer
57
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
6-
import re
7-
import cStringIO
8+
from registry import register_layer
89
from param_attr import ParamAttr
9-
import contextlib
1010

1111
__all__ = [
1212
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
1313
'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim',
1414
'batch_norm', 'accuracy', 'split_lod_tensor', 'While'
1515
]
1616

17+
_REGISTER_LAYER_FROM_OPS = [
18+
'mean', 'mul', 'elementwise_add', 'elementwise_div', 'dropout', 'reshape',
19+
'sigmoid', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits'
20+
]
21+
22+
for _OP in set(_REGISTER_LAYER_FROM_OPS):
23+
globals()[_OP] = register_layer(_OP)
24+
__all__.append(_OP)
25+
1726

1827
def fc(input,
1928
size,
@@ -309,174 +318,6 @@ def create_tensor(dtype, name=None, main_program=None, startup_program=None):
309318
return helper.create_variable(name=helper.name, dtype=dtype)
310319

311320

312-
def _convert_(name):
313-
"""
314-
Formatting.
315-
316-
Args:
317-
name: The name/alias
318-
319-
This function takes in a name and converts it to a standard format of
320-
group1_group2. Where as per the regular expression, group1 can have
321-
alphabets and numbers and group2 has capital alphabets.
322-
323-
"""
324-
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
325-
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
326-
327-
328-
def _generate_doc_string_(op_proto):
329-
"""
330-
Generate docstring by OpProto
331-
332-
Args:
333-
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
334-
335-
Returns:
336-
str: the document string
337-
"""
338-
339-
def _type_to_str_(tp):
340-
return framework_pb2.AttrType.Name(tp)
341-
342-
if not isinstance(op_proto, framework_pb2.OpProto):
343-
raise TypeError("OpProto should be `framework_pb2.OpProto`")
344-
345-
buf = cStringIO.StringIO()
346-
buf.write(op_proto.comment)
347-
buf.write('\nArgs:\n')
348-
for each_input in op_proto.inputs:
349-
line_begin = ' {0}: '.format(_convert_(each_input.name))
350-
buf.write(line_begin)
351-
buf.write(each_input.comment)
352-
buf.write('\n')
353-
buf.write(' ' * len(line_begin))
354-
buf.write('Duplicable: ')
355-
buf.write(str(each_input.duplicable))
356-
buf.write(' Optional: ')
357-
buf.write(str(each_input.dispensable))
358-
buf.write('\n')
359-
360-
for each_attr in op_proto.attrs:
361-
buf.write(' ')
362-
buf.write(each_attr.name)
363-
buf.write(' (')
364-
buf.write(_type_to_str_(each_attr.type))
365-
buf.write('): ')
366-
buf.write(each_attr.comment)
367-
buf.write('\n')
368-
369-
if len(op_proto.outputs) != 0:
370-
buf.write('\nReturns:\n')
371-
buf.write(' ')
372-
for each_opt in op_proto.outputs:
373-
if not each_opt.intermediate:
374-
break
375-
buf.write(each_opt.comment)
376-
377-
return buf.getvalue()
378-
379-
380-
def _create_op_func_(op_type):
381-
"""
382-
Create an Operator for a Function.
383-
384-
Args:
385-
op_type: The name of the operator to be created
386-
387-
This function takes in the operator type (sigmoid, mean , average etc) and
388-
creates the operator functionality.
389-
390-
"""
391-
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
392-
not_intermediate_outputs = \
393-
filter(lambda output: not output.intermediate, op_proto.outputs)
394-
intermediate_outputs = \
395-
filter(lambda output: output.intermediate, op_proto.outputs)
396-
397-
if len(not_intermediate_outputs) != 1:
398-
raise ValueError("Only one non intermediate output operator can be",
399-
"automatically generated")
400-
401-
if not_intermediate_outputs[0].duplicable:
402-
raise ValueError(
403-
"Only non duplicable op can be automatically generated")
404-
405-
for output in intermediate_outputs:
406-
if output.duplicable:
407-
raise ValueError("The op can be automatically generated only when ",
408-
"all intermediate ops are not duplicable")
409-
410-
o_name = not_intermediate_outputs[0].name
411-
intermediate_output_names = [output.name for output in intermediate_outputs]
412-
413-
def infer_and_check_dtype(op_proto, **kwargs):
414-
"""
415-
This function performs the sanity check for dtype and
416-
instance type.
417-
"""
418-
dtype = None
419-
for ipt in op_proto.inputs:
420-
name = _convert_(ipt.name)
421-
val = kwargs.pop(name, [])
422-
if not isinstance(val, list) and not isinstance(val, tuple):
423-
val = [val]
424-
for each in val:
425-
if not isinstance(each, Variable):
426-
raise ValueError("input of {0} must be variable".format(
427-
op_type))
428-
429-
if dtype is None:
430-
dtype = each.dtype
431-
elif dtype != each.dtype:
432-
raise ValueError(
433-
"operator {0} must input same dtype. {1} vs {2}".format(
434-
op_type, dtype, each.dtype))
435-
436-
return dtype
437-
438-
def func(**kwargs):
439-
helper = LayerHelper(op_type, **kwargs)
440-
441-
dtype = infer_and_check_dtype(op_proto, **kwargs)
442-
443-
inputs = dict()
444-
for ipt in op_proto.inputs:
445-
name = _convert_(ipt.name)
446-
val = kwargs.pop(name, [])
447-
if not isinstance(val, list) and not isinstance(val, tuple):
448-
val = [val]
449-
inputs[ipt.name] = val
450-
451-
outputs = dict()
452-
out = helper.create_tmp_variable(dtype=dtype)
453-
outputs[o_name] = [out]
454-
for name in intermediate_output_names:
455-
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
456-
helper.append_op(
457-
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
458-
return helper.append_activation(out)
459-
460-
func.__name__ = op_type
461-
globals()[op_type] = func
462-
func.__doc__ = _generate_doc_string_(op_proto)
463-
global __all__
464-
__all__.append(op_type)
465-
466-
467-
_create_op_func_('mean')
468-
_create_op_func_('mul')
469-
_create_op_func_('elementwise_add')
470-
_create_op_func_('elementwise_div')
471-
_create_op_func_('dropout')
472-
_create_op_func_('reshape')
473-
_create_op_func_('sigmoid')
474-
_create_op_func_('scale')
475-
_create_op_func_('reshape')
476-
_create_op_func_('transpose')
477-
_create_op_func_('sigmoid_cross_entropy_with_logits')
478-
479-
480321
def cast(x, dtype, main_program=None):
481322
"""
482323
This function takes in the input with input_dtype

python/paddle/v2/fluid/registry.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import re
2+
import cStringIO
3+
import warnings
4+
import functools
5+
import inspect
6+
7+
import proto.framework_pb2 as framework_pb2
8+
from framework import OpProtoHolder, Variable, Program, Operator
9+
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
10+
11+
__all__ = ['deprecated', 'register_layer']
12+
13+
14+
def _convert_(name):
15+
"""
16+
Formatting.
17+
18+
Args:
19+
name: The name/alias
20+
21+
This function takes in a name and converts it to a standard format of
22+
group1_group2. Where as per the regular expression, group1 can have
23+
alphabets and numbers and group2 has capital alphabets.
24+
25+
"""
26+
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
27+
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
28+
29+
30+
def _generate_doc_string_(op_proto):
31+
"""
32+
Generate docstring by OpProto
33+
34+
Args:
35+
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
36+
37+
Returns:
38+
str: the document string
39+
"""
40+
41+
def _type_to_str_(tp):
42+
return framework_pb2.AttrType.Name(tp)
43+
44+
if not isinstance(op_proto, framework_pb2.OpProto):
45+
raise TypeError("OpProto should be `framework_pb2.OpProto`")
46+
47+
buf = cStringIO.StringIO()
48+
buf.write(op_proto.comment)
49+
buf.write('\nArgs:\n')
50+
for each_input in op_proto.inputs:
51+
line_begin = ' {0}: '.format(_convert_(each_input.name))
52+
buf.write(line_begin)
53+
buf.write(each_input.comment)
54+
buf.write('\n')
55+
buf.write(' ' * len(line_begin))
56+
buf.write('Duplicable: ')
57+
buf.write(str(each_input.duplicable))
58+
buf.write(' Optional: ')
59+
buf.write(str(each_input.dispensable))
60+
buf.write('\n')
61+
62+
for each_attr in op_proto.attrs:
63+
buf.write(' ')
64+
buf.write(each_attr.name)
65+
buf.write(' (')
66+
buf.write(_type_to_str_(each_attr.type))
67+
buf.write('): ')
68+
buf.write(each_attr.comment)
69+
buf.write('\n')
70+
71+
if len(op_proto.outputs) != 0:
72+
buf.write('\nReturns:\n')
73+
buf.write(' ')
74+
for each_opt in op_proto.outputs:
75+
if not each_opt.intermediate:
76+
break
77+
buf.write(each_opt.comment)
78+
79+
return buf.getvalue()
80+
81+
82+
def register_layer(op_type):
83+
"""
84+
Register an Python layer for an Operator
85+
86+
Args:
87+
op_type: The name of the operator to be created
88+
89+
This function takes in the operator type (sigmoid, mean , average etc) and
90+
creates the operator functionality.
91+
92+
"""
93+
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
94+
not_intermediate_outputs = \
95+
filter(lambda output: not output.intermediate, op_proto.outputs)
96+
intermediate_outputs = \
97+
filter(lambda output: output.intermediate, op_proto.outputs)
98+
99+
if len(not_intermediate_outputs) != 1:
100+
raise ValueError("Only one non intermediate output operator can be",
101+
"automatically generated")
102+
103+
if not_intermediate_outputs[0].duplicable:
104+
raise ValueError(
105+
"Only non duplicable op can be automatically generated")
106+
107+
for output in intermediate_outputs:
108+
if output.duplicable:
109+
raise ValueError("The op can be automatically generated only when ",
110+
"all intermediate ops are not duplicable")
111+
112+
o_name = not_intermediate_outputs[0].name
113+
intermediate_output_names = [output.name for output in intermediate_outputs]
114+
115+
def infer_and_check_dtype(op_proto, **kwargs):
116+
"""
117+
This function performs the sanity check for dtype and
118+
instance type.
119+
"""
120+
dtype = None
121+
for ipt in op_proto.inputs:
122+
name = _convert_(ipt.name)
123+
val = kwargs.pop(name, [])
124+
if not isinstance(val, list) and not isinstance(val, tuple):
125+
val = [val]
126+
for each in val:
127+
if not isinstance(each, Variable):
128+
raise ValueError("input of {0} must be variable".format(
129+
op_type))
130+
131+
if dtype is None:
132+
dtype = each.dtype
133+
elif dtype != each.dtype:
134+
raise ValueError(
135+
"operator {0} must input same dtype. {1} vs {2}".format(
136+
op_type, dtype, each.dtype))
137+
138+
return dtype
139+
140+
def func(**kwargs):
141+
helper = LayerHelper(op_type, **kwargs)
142+
143+
dtype = infer_and_check_dtype(op_proto, **kwargs)
144+
145+
inputs = dict()
146+
for ipt in op_proto.inputs:
147+
name = _convert_(ipt.name)
148+
val = kwargs.pop(name, [])
149+
if not isinstance(val, list) and not isinstance(val, tuple):
150+
val = [val]
151+
inputs[ipt.name] = val
152+
153+
outputs = dict()
154+
out = helper.create_tmp_variable(dtype=dtype)
155+
outputs[o_name] = [out]
156+
for name in intermediate_output_names:
157+
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
158+
helper.append_op(
159+
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
160+
return helper.append_activation(out)
161+
162+
func.__name__ = op_type
163+
func.__doc__ = _generate_doc_string_(op_proto)
164+
return func
165+
166+
167+
def deprecated(func_or_class):
168+
"""
169+
Deprecated warning decorator. It will result a warning message.
170+
Should be used before class or function, member function
171+
"""
172+
173+
@functools.wraps(func)
174+
def func_wrapper(*args, **kwargs):
175+
"""
176+
Wrap func with deprecated warning
177+
"""
178+
warnings.simplefilter('always', DeprecationWarning) #turn off filter
179+
warnings.warn(
180+
"Call to deprecated function {}.".format(func.__name__),
181+
category=DeprecationWarning,
182+
stacklevel=2)
183+
warnings.simplefilter('default', DeprecationWarning) #reset filter
184+
return func(*args, **kwargs)
185+
186+
return func_wrapper

0 commit comments

Comments
 (0)