|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import sys |
15 | 16 | import re |
16 | 17 | from graphviz import GraphPreviewGenerator |
17 | 18 | import proto.framework_pb2 as framework_pb2 |
18 | 19 |
|
| 20 | +_vartype2str_ = [ |
| 21 | + "UNK", |
| 22 | + "LoDTensor", |
| 23 | + "SelectedRows", |
| 24 | + "FeedMinibatch", |
| 25 | + "FetchList", |
| 26 | + "StepScopes", |
| 27 | + "LodRankTable", |
| 28 | + "LoDTensorArray", |
| 29 | + "PlaceList", |
| 30 | +] |
| 31 | +_dtype2str_ = [ |
| 32 | + "bool", |
| 33 | + "int16", |
| 34 | + "int32", |
| 35 | + "int64", |
| 36 | + "float16", |
| 37 | + "float32", |
| 38 | + "float64", |
| 39 | +] |
| 40 | + |
| 41 | + |
| 42 | +def repr_data_type(type): |
| 43 | + return _dtype2str_[type] |
| 44 | + |
| 45 | + |
| 46 | +def repr_tensor(proto): |
| 47 | + return "tensor(type={}, shape={})".format(_dtype2str_[int(proto.data_type)], |
| 48 | + str(proto.dims)) |
| 49 | + |
| 50 | + |
| 51 | +reprtpl = "{ttype} {name} ({reprs})" |
| 52 | + |
| 53 | + |
| 54 | +def repr_lodtensor(proto): |
| 55 | + if not proto.lod_tensor: return |
| 56 | + level = proto.lod_tensor.lod_level |
| 57 | + reprs = repr_tensor(proto.lod_tensor.tensor) |
| 58 | + return reprtpl.format( |
| 59 | + ttype="LoDTensor" if level > 0 else "Tensor", |
| 60 | + name=proto.name, |
| 61 | + reprs="level=%d, %s" % (level, reprs) if level > 0 else reprs) |
| 62 | + |
| 63 | + |
| 64 | +def repr_selected_rows(proto): |
| 65 | + if not proto.selected_rows: return |
| 66 | + return reprtpl.format( |
| 67 | + ttype="SelectedRows", |
| 68 | + name=proto.name, |
| 69 | + reprs=repr_tensor(proto.selected_rows)) |
| 70 | + |
| 71 | + |
| 72 | +def repr_tensor_array(proto): |
| 73 | + if not proto.tensor_array: return |
| 74 | + return reprtpl.format( |
| 75 | + ttype="TensorArray", |
| 76 | + name=proto.name, |
| 77 | + reprs="level=%d, %s" % (proto.tensor_array.lod_level, |
| 78 | + repr_tensor(proto.lod_tensor))) |
| 79 | + |
| 80 | + |
| 81 | +type_handlers = [ |
| 82 | + repr_lodtensor, |
| 83 | + repr_selected_rows, |
| 84 | + repr_tensor_array, |
| 85 | +] |
| 86 | + |
| 87 | + |
| 88 | +def repr_var(vardesc): |
| 89 | + for handler in type_handlers: |
| 90 | + res = handler(vardesc) |
| 91 | + if res: |
| 92 | + return res |
| 93 | + |
| 94 | + |
| 95 | +def pprint_program_codes(program_desc): |
| 96 | + reprs = [] |
| 97 | + for block_idx in range(program_desc.num_blocks()): |
| 98 | + block_desc = program_desc.block(block_idx) |
| 99 | + block_repr = pprint_block_codes(block_desc) |
| 100 | + reprs.append(block_repr) |
| 101 | + return '\n'.join(reprs) |
| 102 | + |
| 103 | + |
| 104 | +def pprint_block_codes(block_desc, show_backward=False): |
| 105 | + def is_op_backward(op_desc): |
| 106 | + if op_desc.type.endswith('_grad'): return True |
| 107 | + |
| 108 | + def is_var_backward(var): |
| 109 | + if "@GRAD" in var.parameter: return True |
| 110 | + for arg in var.arguments: |
| 111 | + if "@GRAD" in arg: return True |
| 112 | + |
| 113 | + for var in op_desc.inputs: |
| 114 | + if is_var_backward(var): return True |
| 115 | + for var in op_desc.outputs: |
| 116 | + if is_var_backward(var): return True |
| 117 | + return False |
| 118 | + |
| 119 | + def is_var_backward(var_desc): |
| 120 | + return "@GRAD" in var_desc.name |
| 121 | + |
| 122 | + if type(block_desc) is not framework_pb2.BlockDesc: |
| 123 | + block_desc = framework_pb2.BlockDesc.FromString( |
| 124 | + block_desc.serialize_to_string()) |
| 125 | + var_reprs = [] |
| 126 | + op_reprs = [] |
| 127 | + for var in block_desc.vars: |
| 128 | + if not show_backward and is_var_backward(var): |
| 129 | + continue |
| 130 | + var_reprs.append(repr_var(var)) |
| 131 | + |
| 132 | + for op in block_desc.ops: |
| 133 | + if not show_backward and is_op_backward(op): continue |
| 134 | + op_reprs.append(repr_op(op)) |
| 135 | + |
| 136 | + tpl = "// block-{idx} parent-{pidx}\n// variables\n{vars}\n\n// operators\n{ops}\n" |
| 137 | + return tpl.format( |
| 138 | + idx=block_desc.idx, |
| 139 | + pidx=block_desc.parent_idx, |
| 140 | + vars='\n'.join(var_reprs), |
| 141 | + ops='\n'.join(op_reprs), ) |
| 142 | + |
| 143 | + |
| 144 | +def repr_attr(desc): |
| 145 | + tpl = "{key}={value}" |
| 146 | + valgetter = [ |
| 147 | + lambda attr: attr.i, |
| 148 | + lambda attr: attr.f, |
| 149 | + lambda attr: attr.s, |
| 150 | + lambda attr: attr.ints, |
| 151 | + lambda attr: attr.floats, |
| 152 | + lambda attr: attr.strings, |
| 153 | + lambda attr: attr.b, |
| 154 | + lambda attr: attr.bools, |
| 155 | + lambda attr: attr.block_idx, |
| 156 | + lambda attr: attr.l, |
| 157 | + ] |
| 158 | + key = desc.name |
| 159 | + value = valgetter[desc.type](desc) |
| 160 | + if key == "dtype": |
| 161 | + value = repr_data_type(value) |
| 162 | + return tpl.format(key=key, value=str(value)), (key, value) |
| 163 | + |
| 164 | + |
| 165 | +def _repr_op_fill_constant(optype, inputs, outputs, attrs): |
| 166 | + if optype == "fill_constant": |
| 167 | + return "{output} = {data} [shape={shape}]".format( |
| 168 | + output=','.join(outputs), |
| 169 | + data=attrs['value'], |
| 170 | + shape=str(attrs['shape'])) |
| 171 | + |
| 172 | + |
| 173 | +op_repr_handlers = [_repr_op_fill_constant, ] |
| 174 | + |
| 175 | + |
| 176 | +def repr_op(opdesc): |
| 177 | + optype = None |
| 178 | + attrs = [] |
| 179 | + attr_dict = {} |
| 180 | + is_target = None |
| 181 | + inputs = [] |
| 182 | + outputs = [] |
| 183 | + |
| 184 | + tpl = "{outputs} = {optype}({inputs}{is_target}) [{attrs}]" |
| 185 | + args2value = lambda args: args[0] if len(args) == 1 else str(list(args)) |
| 186 | + for var in opdesc.inputs: |
| 187 | + key = var.parameter |
| 188 | + value = args2value(var.arguments) |
| 189 | + inputs.append("%s=%s" % (key, value)) |
| 190 | + for var in opdesc.outputs: |
| 191 | + value = args2value(var.arguments) |
| 192 | + outputs.append(value) |
| 193 | + for attr in opdesc.attrs: |
| 194 | + attr_repr, attr_pair = repr_attr(attr) |
| 195 | + attrs.append(attr_repr) |
| 196 | + attr_dict[attr_pair[0]] = attr_pair[1] |
| 197 | + |
| 198 | + is_target = opdesc.is_target |
| 199 | + |
| 200 | + for handler in op_repr_handlers: |
| 201 | + res = handler(opdesc.type, inputs, outputs, attr_dict) |
| 202 | + if res: return res |
| 203 | + |
| 204 | + return tpl.format( |
| 205 | + outputs=', '.join(outputs), |
| 206 | + optype=opdesc.type, |
| 207 | + inputs=', '.join(inputs), |
| 208 | + attrs="{%s}" % ','.join(attrs), |
| 209 | + is_target=", is_target" if is_target else "") |
| 210 | + |
19 | 211 |
|
20 | 212 | def draw_block_graphviz(block, highlights=None, path="./temp.dot"): |
21 | 213 | ''' |
|
0 commit comments