Skip to content

Commit 6f28084

Browse files
authored
debug/format protobuf to human-readable codes (#8086)
1 parent f3d5923 commit 6f28084

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed

python/paddle/v2/fluid/debuger.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,202 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import re
1617
from graphviz import GraphPreviewGenerator
1718
import proto.framework_pb2 as framework_pb2
1819

20+
_vartype2str_ = [
21+
"UNK",
22+
"LoDTensor",
23+
"SelectedRows",
24+
"FeedMinibatch",
25+
"FetchList",
26+
"StepScopes",
27+
"LodRankTable",
28+
"LoDTensorArray",
29+
"PlaceList",
30+
]
31+
_dtype2str_ = [
32+
"bool",
33+
"int16",
34+
"int32",
35+
"int64",
36+
"float16",
37+
"float32",
38+
"float64",
39+
]
40+
41+
42+
def repr_data_type(type):
43+
return _dtype2str_[type]
44+
45+
46+
def repr_tensor(proto):
    """Render a TensorDesc proto as ``tensor(type=..., shape=...)``."""
    dtype_name = _dtype2str_[int(proto.data_type)]
    return "tensor(type={}, shape={})".format(dtype_name, str(proto.dims))
49+
50+
51+
# Shared one-line template for variable representations:
# "<variable type> <variable name> (<details>)".
reprtpl = "{ttype} {name} ({reprs})"
52+
53+
54+
def repr_lodtensor(proto):
    """Render a VarDesc holding a (LoD)Tensor; implicitly None otherwise.

    A lod_level of 0 is reported as a plain "Tensor"; a positive level is
    reported as "LoDTensor" with the level included in the details.
    """
    # NOTE(review): truthiness of a protobuf sub-message reflects emptiness,
    # not field presence — confirm whether HasField('lod_tensor') is intended.
    if not proto.lod_tensor: return
    level = proto.lod_tensor.lod_level
    reprs = repr_tensor(proto.lod_tensor.tensor)
    return reprtpl.format(
        ttype="LoDTensor" if level > 0 else "Tensor",
        name=proto.name,
        reprs="level=%d, %s" % (level, reprs) if level > 0 else reprs)
62+
63+
64+
def repr_selected_rows(proto):
    """Render a VarDesc holding SelectedRows; implicitly None otherwise."""
    # NOTE(review): same presence-check caveat as repr_lodtensor — an empty
    # sub-message may still be truthy; verify against the proto definition.
    if not proto.selected_rows: return
    return reprtpl.format(
        ttype="SelectedRows",
        name=proto.name,
        reprs=repr_tensor(proto.selected_rows))
70+
71+
72+
def repr_tensor_array(proto):
    """Render a VarDesc holding a LoDTensorArray; implicitly None otherwise.

    Produces e.g. "TensorArray arr (level=1, tensor(type=float32, shape=[2]))".
    """
    # NOTE(review): truthiness of a protobuf sub-message reflects emptiness,
    # not field presence — confirm whether HasField('tensor_array') is intended.
    if not proto.tensor_array: return
    return reprtpl.format(
        ttype="TensorArray",
        name=proto.name,
        # BUG FIX: previously formatted proto.lod_tensor, the wrong field
        # for a tensor-array variable (LoDTensorDesc has no data_type/dims).
        # Use the tensor description carried by the tensor_array itself.
        reprs="level=%d, %s" % (proto.tensor_array.lod_level,
                                repr_tensor(proto.tensor_array.tensor)))
79+
80+
81+
# Variable formatters tried in order by repr_var(); the first handler that
# returns a non-empty string wins.
type_handlers = [
    repr_lodtensor,
    repr_selected_rows,
    repr_tensor_array,
]
86+
87+
88+
def repr_var(vardesc):
    """Format a VarDesc using the first type handler that produces output.

    Returns None (implicitly) when no handler recognizes the variable.
    """
    return next(
        (text for text in (handler(vardesc) for handler in type_handlers)
         if text),
        None)
93+
94+
95+
def pprint_program_codes(program_desc):
    """Return a pseudo-code dump of every block in a ProgramDesc."""
    block_texts = [
        pprint_block_codes(program_desc.block(idx))
        for idx in range(program_desc.num_blocks())
    ]
    return '\n'.join(block_texts)
102+
103+
104+
def pprint_block_codes(block_desc, show_backward=False):
    """Render one block as C-style pseudo code: variables, then operators.

    When show_backward is False, gradient variables (name contains "@GRAD")
    and backward operators (type ends with "_grad", or any argument contains
    "@GRAD") are omitted from the dump.
    """

    def is_op_backward(op_desc):
        # Backward ops are detected by their type suffix ...
        if op_desc.type.endswith('_grad'): return True

        # ... or by any gradient-named parameter/argument they touch.
        # NOTE: this inner helper intentionally shadows the outer
        # is_var_backward; it inspects op in/out slots, not VarDescs.
        def is_var_backward(var):
            if "@GRAD" in var.parameter: return True
            for arg in var.arguments:
                if "@GRAD" in arg: return True

        for var in op_desc.inputs:
            if is_var_backward(var): return True
        for var in op_desc.outputs:
            if is_var_backward(var): return True
        return False

    def is_var_backward(var_desc):
        return "@GRAD" in var_desc.name

    # Accept either a raw framework_pb2.BlockDesc or a wrapper that can be
    # serialized; round-trip through the wire format to normalize the type.
    if type(block_desc) is not framework_pb2.BlockDesc:
        block_desc = framework_pb2.BlockDesc.FromString(
            block_desc.serialize_to_string())
    var_reprs = []
    op_reprs = []
    for var in block_desc.vars:
        if not show_backward and is_var_backward(var):
            continue
        var_reprs.append(repr_var(var))

    for op in block_desc.ops:
        if not show_backward and is_op_backward(op): continue
        op_reprs.append(repr_op(op))

    tpl = "// block-{idx} parent-{pidx}\n// variables\n{vars}\n\n// operators\n{ops}\n"
    return tpl.format(
        idx=block_desc.idx,
        pidx=block_desc.parent_idx,
        vars='\n'.join(var_reprs),
        ops='\n'.join(op_reprs), )
142+
143+
144+
def repr_attr(desc):
    """Format an operator attribute proto as "key=value".

    Returns a 2-tuple: the formatted string, and the raw (key, value) pair.
    """
    # One accessor per AttrType ordinal; desc.type selects which proto
    # field actually carries the attribute's value.
    accessors = [
        lambda attr: attr.i,
        lambda attr: attr.f,
        lambda attr: attr.s,
        lambda attr: attr.ints,
        lambda attr: attr.floats,
        lambda attr: attr.strings,
        lambda attr: attr.b,
        lambda attr: attr.bools,
        lambda attr: attr.block_idx,
        lambda attr: attr.l,
    ]
    key = desc.name
    value = accessors[desc.type](desc)
    if key == "dtype":
        # dtype attributes carry a data-type enum ordinal; show its name.
        value = repr_data_type(value)
    return "{key}={value}".format(key=key, value=str(value)), (key, value)
163+
164+
165+
def _repr_op_fill_constant(optype, inputs, outputs, attrs):
166+
if optype == "fill_constant":
167+
return "{output} = {data} [shape={shape}]".format(
168+
output=','.join(outputs),
169+
data=attrs['value'],
170+
shape=str(attrs['shape']))
171+
172+
173+
op_repr_handlers = [_repr_op_fill_constant, ]
174+
175+
176+
def repr_op(opdesc):
    """Format an OpDesc proto as one pseudo-code line.

    Generic shape: "<outputs> = <type>(<inputs>[, is_target]) [{attrs}]".
    Handlers registered in op_repr_handlers may override the generic form.
    """

    def _args(arguments):
        # A single argument is shown bare; several are shown as a list.
        return arguments[0] if len(arguments) == 1 else str(list(arguments))

    inputs = [
        "%s=%s" % (var.parameter, _args(var.arguments))
        for var in opdesc.inputs
    ]
    outputs = [_args(var.arguments) for var in opdesc.outputs]

    attrs = []
    attr_dict = {}
    for attr in opdesc.attrs:
        text, (key, value) = repr_attr(attr)
        attrs.append(text)
        attr_dict[key] = value

    # Give op-specific pretty printers first refusal.
    for handler in op_repr_handlers:
        special = handler(opdesc.type, inputs, outputs, attr_dict)
        if special:
            return special

    return "{outputs} = {optype}({inputs}{is_target}) [{attrs}]".format(
        outputs=', '.join(outputs),
        optype=opdesc.type,
        inputs=', '.join(inputs),
        attrs="{%s}" % ','.join(attrs),
        is_target=", is_target" if opdesc.is_target else "")
210+
19211

20212
def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
21213
'''

0 commit comments

Comments
 (0)