|
1 | 1 | import paddle.v2.framework.core as core |
2 | 2 | import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 |
| 3 | +import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 |
| 4 | +import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 |
| 5 | +import cStringIO |
3 | 6 |
|
4 | 7 |
|
5 | 8 | def get_all_op_protos(): |
| 9 | + """ |
| 10 | + Get all registered op proto from Paddle C++ |
| 11 | + :return: list of OpProto |
| 12 | + """ |
6 | 13 | protostrs = core.get_all_op_protos() |
7 | 14 | ret_values = [] |
8 | 15 | for pbstr in protostrs: |
9 | 16 | op_proto = op_proto_pb2.OpProto.FromString(str(pbstr)) |
10 | 17 | ret_values.append(op_proto) |
11 | 18 | return ret_values |
| 19 | + |
| 20 | + |
| 21 | +class OpDescCreationMethod(object): |
| 22 | + """ |
| 23 | + A Functor object to convert user input(use key word args) to OpDesc based on |
| 24 | + OpProto. |
| 25 | + |
| 26 | + :param op_proto: The OpProto object. |
| 27 | + :type op_proto: op_proto_pb2.OpProto |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__(self, op_proto): |
| 31 | + if not isinstance(op_proto, op_proto_pb2.OpProto): |
| 32 | + raise TypeError("Argument should be OpProto") |
| 33 | + self.__op_proto__ = op_proto |
| 34 | + |
| 35 | + def __call__(self, *args, **kwargs): |
| 36 | + """ |
| 37 | + Convert user input to OpDesc. Only key-word args are supported. |
| 38 | + :return: OpDesc based on user input |
| 39 | + :rtype: op_desc_pb2.OpDesc |
| 40 | + """ |
| 41 | + if len(args) != 0: |
| 42 | + raise ValueError("Only keyword arguments is supported by Paddle") |
| 43 | + op_desc = op_desc_pb2.OpDesc() |
| 44 | + |
| 45 | + # Inputs |
| 46 | + ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output( |
| 47 | + "input", kwargs, self.__op_proto__.inputs) |
| 48 | + op_desc.inputs.extend(ipts) |
| 49 | + if ipt_format is not None: |
| 50 | + op_desc.attrs.extend([ipt_format]) |
| 51 | + |
| 52 | + # Outputs |
| 53 | + outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output( |
| 54 | + "output", kwargs, self.__op_proto__.outputs) |
| 55 | + op_desc.outputs.extend(outs) |
| 56 | + if out_format is not None: |
| 57 | + op_desc.attrs.extend([out_format]) |
| 58 | + if len(tmp_index) != 0: |
| 59 | + tmp_index_attr = op_desc.attrs.add() |
| 60 | + tmp_index_attr.type = attr_type_pb2.INTS |
| 61 | + tmp_index_attr.name = "temporary_index" |
| 62 | + tmp_index_attr.ints.extend(tmp_index) |
| 63 | + |
| 64 | + # Types |
| 65 | + op_desc.type = self.__op_proto__.type |
| 66 | + |
| 67 | + # Attrs |
| 68 | + for attr in self.__op_proto__.attrs: |
| 69 | + if attr.generated: |
| 70 | + continue |
| 71 | + user_defined_attr = kwargs.get(attr.name, None) |
| 72 | + if user_defined_attr is not None: |
| 73 | + new_attr = op_desc.attrs.add() |
| 74 | + new_attr.name = attr.name |
| 75 | + new_attr.type = attr.type |
| 76 | + if attr.type == attr_type_pb2.INT: |
| 77 | + new_attr.i = user_defined_attr |
| 78 | + elif attr.type == attr_type_pb2.FLOAT: |
| 79 | + new_attr.f = user_defined_attr |
| 80 | + elif attr.type == attr_type_pb2.STRING: |
| 81 | + new_attr.s = user_defined_attr |
| 82 | + elif attr.type == attr_type_pb2.INTS: |
| 83 | + new_attr.ints.extend(user_defined_attr) |
| 84 | + elif attr.type == attr_type_pb2.FLOATS: |
| 85 | + new_attr.floats.extend(user_defined_attr) |
| 86 | + elif attr.type == attr_type_pb2.STRINGS: |
| 87 | + new_attr.strings.extend(user_defined_attr) |
| 88 | + else: |
| 89 | + raise NotImplementedError("Not support attribute type " + |
| 90 | + attr.type) |
| 91 | + |
| 92 | + return op_desc |
| 93 | + |
| 94 | + @staticmethod |
| 95 | + def extract_input_or_output(in_out, kwargs, meta): |
| 96 | + """ |
| 97 | + Extract input variable names or output variable names from key-word |
| 98 | + arguments, which base on VarProtos. |
| 99 | + |
| 100 | + :param in_out: "input" or "output" |
| 101 | + :param kwargs: key-word arguments that user inputted. |
| 102 | + :param meta: a list of VarProto |
| 103 | + :return: The three object will be return. The variable names. The |
| 104 | + input_format or output_format attribute(None if the input or output is |
| 105 | + not multiple). The temporary variable index list. |
| 106 | + """ |
| 107 | + multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta)) |
| 108 | + tmp_index = [] |
| 109 | + retv = [] |
| 110 | + if multiple: |
| 111 | + var_format = op_desc_pb2.AttrDesc() |
| 112 | + var_format.type = attr_type_pb2.INTS |
| 113 | + var_format.name = "%s_format" % in_out |
| 114 | + var_format.ints.append(0) |
| 115 | + |
| 116 | + for var in meta: |
| 117 | + var_name = var.name |
| 118 | + |
| 119 | + if var.temporary: |
| 120 | + var_name = [core.var_names.temp()] |
| 121 | + tmp_index.append(len(retv)) |
| 122 | + else: |
| 123 | + var_name = kwargs.get(var_name, []) |
| 124 | + if not isinstance(var_name, list): |
| 125 | + var_name = [var_name] |
| 126 | + retv.extend(var_name) |
| 127 | + var_format.ints.append(len(var_name) + var_format.ints[-1]) |
| 128 | + return retv, var_format, tmp_index |
| 129 | + else: |
| 130 | + for var in meta: |
| 131 | + if var.temporary: |
| 132 | + retv.append(kwargs.get(var.name, core.var_names.temp())) |
| 133 | + tmp_index.append(len(retv)) |
| 134 | + else: |
| 135 | + retv.append(kwargs.get(var.name, core.var_names.empty())) |
| 136 | + return retv, None, tmp_index |
| 137 | + |
| 138 | + @staticmethod |
| 139 | + def any_is_true(generator): |
| 140 | + """ |
| 141 | + Reduce a bool array to one. If any of them is True, then return True. |
| 142 | + """ |
| 143 | + for flag in generator: |
| 144 | + if flag: |
| 145 | + return True |
| 146 | + return False |
| 147 | + |
| 148 | + |
| 149 | +def get_docstring_from_op_proto(op_proto): |
| 150 | + """ |
| 151 | + Generate docstring from a OpProto |
| 152 | + :param op_proto: a OpProto instance. |
| 153 | + :type op_proto: op_proto_pb2.OpProto |
| 154 | + :return: docstring |
| 155 | + """ |
| 156 | + if not isinstance(op_proto, op_proto_pb2.OpProto): |
| 157 | + raise TypeError("Input must be OpProto") |
| 158 | + f = cStringIO.StringIO() |
| 159 | + f.write(op_proto.comment) |
| 160 | + f.write("\n") |
| 161 | + |
| 162 | + def __append_param__(name, comment, type): |
| 163 | + # Maybe replace the following line with template engine is better. |
| 164 | + f.write(":param ") |
| 165 | + f.write(name) |
| 166 | + f.write(": ") |
| 167 | + f.write(comment) |
| 168 | + f.write("\n") |
| 169 | + f.write(":type ") |
| 170 | + f.write(name) |
| 171 | + f.write(": ") |
| 172 | + f.write(type) |
| 173 | + f.write("\n") |
| 174 | + |
| 175 | + for ipt in op_proto.inputs: |
| 176 | + __append_param__(ipt.name, ipt.comment, "list | basestr" |
| 177 | + if ipt.multiple else "basestr") |
| 178 | + |
| 179 | + temp_var_prefix = \ |
| 180 | + "This is a temporary variable. It does not have to set by user. " |
| 181 | + for opt in op_proto.outputs: |
| 182 | + __append_param__(opt.name, opt.comment if not opt.temporary else |
| 183 | + temp_var_prefix + opt.comment, "list | basestr" |
| 184 | + if opt.multiple else "basestr") |
| 185 | + |
| 186 | + for attr in op_proto.attrs: |
| 187 | + attr_type = None |
| 188 | + if attr.type == attr_type_pb2.INT: |
| 189 | + attr_type = "int" |
| 190 | + elif attr.type == attr_type_pb2.FLOAT: |
| 191 | + attr_type = "float" |
| 192 | + elif attr.type == attr_type_pb2.STRING: |
| 193 | + attr_type = "basestr" |
| 194 | + elif attr.type == attr_type_pb2.INTS: |
| 195 | + attr_type = "list of int" |
| 196 | + elif attr.type == attr_type_pb2.FLOATS: |
| 197 | + attr_type = "list of float" |
| 198 | + elif attr.type == attr_type_pb2.STRINGS: |
| 199 | + attr_type = "list of basestr" |
| 200 | + |
| 201 | + if attr_type is None: |
| 202 | + raise RuntimeError("Not supported attribute type " + attr.type) |
| 203 | + |
| 204 | + __append_param__(attr.name, attr.comment, attr_type) |
| 205 | + |
| 206 | + return f.getvalue() |
| 207 | + |
| 208 | + |
| 209 | +def create_op_creation_method(op_proto): |
| 210 | + """ |
| 211 | + Generate op creation method for an OpProto |
| 212 | + """ |
| 213 | + method = OpDescCreationMethod(op_proto) |
| 214 | + |
| 215 | + def __impl__(*args, **kwargs): |
| 216 | + opdesc = method(*args, **kwargs) |
| 217 | + return core.Operator.create(opdesc.SerializeToString()) |
| 218 | + |
| 219 | + __impl__.__doc__ = get_docstring_from_op_proto(op_proto) |
| 220 | + return __impl__ |
| 221 | + |
| 222 | + |
| 223 | +class OpCreationsHolder(object): |
| 224 | + """ |
| 225 | + A object will holds all op creation methods. |
| 226 | + |
| 227 | + Use `op_creations.xxx_op` to access them. |
| 228 | + """ |
| 229 | + pass |
| 230 | + |
| 231 | + |
| 232 | +op_creations = OpCreationsHolder() |
| 233 | + |
| 234 | + |
| 235 | +def __bootstrap__(): |
| 236 | + """ |
| 237 | + Bootstrap function for this module. It will dynamic create all op creation |
| 238 | + methods in runtime. |
| 239 | + """ |
| 240 | + for op_proto in get_all_op_protos(): |
| 241 | + func = create_op_creation_method(op_proto) |
| 242 | + func.__name__ = str(op_proto.type) |
| 243 | + setattr(op_creations, func.__name__, func) |
| 244 | + |
| 245 | + |
| 246 | +__bootstrap__() |
0 commit comments