|
1 | | -import core |
| 1 | +import contextlib |
| 2 | + |
2 | 3 | import proto.framework_pb2 as framework_pb2 |
| 4 | +import core |
3 | 5 | from framework import OpProtoHolder, Variable, Program, Operator |
4 | 6 | from initializer import Constant, Normal, Xavier, Initializer |
5 | 7 | from paddle.v2.fluid.layer_helper import LayerHelper, unique_name |
6 | | -import re |
7 | | -import cStringIO |
| 8 | +from registry import register_layer |
8 | 9 | from param_attr import ParamAttr |
9 | | -import contextlib |
10 | 10 |
|
11 | 11 | __all__ = [ |
12 | 12 | 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', |
13 | 13 | 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim', |
14 | 14 | 'batch_norm', 'accuracy', 'split_lod_tensor', 'While' |
15 | 15 | ] |
16 | 16 |
|
| 17 | +_REGISTER_LAYER_FROM_OPS = [ |
| 18 | + 'mean', 'mul', 'elementwise_add', 'elementwise_div', 'dropout', 'reshape', |
| 19 | + 'sigmoid', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits' |
| 20 | +] |
| 21 | + |
| 22 | +for _OP in set(_REGISTER_LAYER_FROM_OPS): |
| 23 | + globals()[_OP] = register_layer(_OP) |
| 24 | + __all__.append(_OP) |
| 25 | + |
17 | 26 |
|
18 | 27 | def fc(input, |
19 | 28 | size, |
@@ -309,174 +318,6 @@ def create_tensor(dtype, name=None, main_program=None, startup_program=None): |
309 | 318 | return helper.create_variable(name=helper.name, dtype=dtype) |
310 | 319 |
|
311 | 320 |
|
312 | | -def _convert_(name): |
313 | | - """ |
314 | | - Formatting. |
315 | | -
|
316 | | - Args: |
317 | | - name: The name/alias |
318 | | -
|
319 | | - This function takes in a name and converts it to a standard format of |
320 | | - group1_group2. Where as per the regular expression, group1 can have |
321 | | - alphabets and numbers and group2 has capital alphabets. |
322 | | -
|
323 | | - """ |
324 | | - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) |
325 | | - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() |
326 | | - |
327 | | - |
328 | | -def _generate_doc_string_(op_proto): |
329 | | - """ |
330 | | - Generate docstring by OpProto |
331 | | -
|
332 | | - Args: |
333 | | - op_proto (framework_pb2.OpProto): a protobuf message typed OpProto |
334 | | -
|
335 | | - Returns: |
336 | | - str: the document string |
337 | | - """ |
338 | | - |
339 | | - def _type_to_str_(tp): |
340 | | - return framework_pb2.AttrType.Name(tp) |
341 | | - |
342 | | - if not isinstance(op_proto, framework_pb2.OpProto): |
343 | | - raise TypeError("OpProto should be `framework_pb2.OpProto`") |
344 | | - |
345 | | - buf = cStringIO.StringIO() |
346 | | - buf.write(op_proto.comment) |
347 | | - buf.write('\nArgs:\n') |
348 | | - for each_input in op_proto.inputs: |
349 | | - line_begin = ' {0}: '.format(_convert_(each_input.name)) |
350 | | - buf.write(line_begin) |
351 | | - buf.write(each_input.comment) |
352 | | - buf.write('\n') |
353 | | - buf.write(' ' * len(line_begin)) |
354 | | - buf.write('Duplicable: ') |
355 | | - buf.write(str(each_input.duplicable)) |
356 | | - buf.write(' Optional: ') |
357 | | - buf.write(str(each_input.dispensable)) |
358 | | - buf.write('\n') |
359 | | - |
360 | | - for each_attr in op_proto.attrs: |
361 | | - buf.write(' ') |
362 | | - buf.write(each_attr.name) |
363 | | - buf.write(' (') |
364 | | - buf.write(_type_to_str_(each_attr.type)) |
365 | | - buf.write('): ') |
366 | | - buf.write(each_attr.comment) |
367 | | - buf.write('\n') |
368 | | - |
369 | | - if len(op_proto.outputs) != 0: |
370 | | - buf.write('\nReturns:\n') |
371 | | - buf.write(' ') |
372 | | - for each_opt in op_proto.outputs: |
373 | | - if not each_opt.intermediate: |
374 | | - break |
375 | | - buf.write(each_opt.comment) |
376 | | - |
377 | | - return buf.getvalue() |
378 | | - |
379 | | - |
380 | | -def _create_op_func_(op_type): |
381 | | - """ |
382 | | - Create an Operator for a Function. |
383 | | -
|
384 | | - Args: |
385 | | - op_type: The name of the operator to be created |
386 | | -
|
387 | | - This function takes in the operator type (sigmoid, mean , average etc) and |
388 | | - creates the operator functionality. |
389 | | -
|
390 | | - """ |
391 | | - op_proto = OpProtoHolder.instance().get_op_proto(op_type) |
392 | | - not_intermediate_outputs = \ |
393 | | - filter(lambda output: not output.intermediate, op_proto.outputs) |
394 | | - intermediate_outputs = \ |
395 | | - filter(lambda output: output.intermediate, op_proto.outputs) |
396 | | - |
397 | | - if len(not_intermediate_outputs) != 1: |
398 | | - raise ValueError("Only one non intermediate output operator can be", |
399 | | - "automatically generated") |
400 | | - |
401 | | - if not_intermediate_outputs[0].duplicable: |
402 | | - raise ValueError( |
403 | | - "Only non duplicable op can be automatically generated") |
404 | | - |
405 | | - for output in intermediate_outputs: |
406 | | - if output.duplicable: |
407 | | - raise ValueError("The op can be automatically generated only when ", |
408 | | - "all intermediate ops are not duplicable") |
409 | | - |
410 | | - o_name = not_intermediate_outputs[0].name |
411 | | - intermediate_output_names = [output.name for output in intermediate_outputs] |
412 | | - |
413 | | - def infer_and_check_dtype(op_proto, **kwargs): |
414 | | - """ |
415 | | - This function performs the sanity check for dtype and |
416 | | - instance type. |
417 | | - """ |
418 | | - dtype = None |
419 | | - for ipt in op_proto.inputs: |
420 | | - name = _convert_(ipt.name) |
421 | | - val = kwargs.pop(name, []) |
422 | | - if not isinstance(val, list) and not isinstance(val, tuple): |
423 | | - val = [val] |
424 | | - for each in val: |
425 | | - if not isinstance(each, Variable): |
426 | | - raise ValueError("input of {0} must be variable".format( |
427 | | - op_type)) |
428 | | - |
429 | | - if dtype is None: |
430 | | - dtype = each.dtype |
431 | | - elif dtype != each.dtype: |
432 | | - raise ValueError( |
433 | | - "operator {0} must input same dtype. {1} vs {2}".format( |
434 | | - op_type, dtype, each.dtype)) |
435 | | - |
436 | | - return dtype |
437 | | - |
438 | | - def func(**kwargs): |
439 | | - helper = LayerHelper(op_type, **kwargs) |
440 | | - |
441 | | - dtype = infer_and_check_dtype(op_proto, **kwargs) |
442 | | - |
443 | | - inputs = dict() |
444 | | - for ipt in op_proto.inputs: |
445 | | - name = _convert_(ipt.name) |
446 | | - val = kwargs.pop(name, []) |
447 | | - if not isinstance(val, list) and not isinstance(val, tuple): |
448 | | - val = [val] |
449 | | - inputs[ipt.name] = val |
450 | | - |
451 | | - outputs = dict() |
452 | | - out = helper.create_tmp_variable(dtype=dtype) |
453 | | - outputs[o_name] = [out] |
454 | | - for name in intermediate_output_names: |
455 | | - outputs[name] = [helper.create_tmp_variable(dtype=dtype)] |
456 | | - helper.append_op( |
457 | | - type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs) |
458 | | - return helper.append_activation(out) |
459 | | - |
460 | | - func.__name__ = op_type |
461 | | - globals()[op_type] = func |
462 | | - func.__doc__ = _generate_doc_string_(op_proto) |
463 | | - global __all__ |
464 | | - __all__.append(op_type) |
465 | | - |
466 | | - |
467 | | -_create_op_func_('mean') |
468 | | -_create_op_func_('mul') |
469 | | -_create_op_func_('elementwise_add') |
470 | | -_create_op_func_('elementwise_div') |
471 | | -_create_op_func_('dropout') |
472 | | -_create_op_func_('reshape') |
473 | | -_create_op_func_('sigmoid') |
474 | | -_create_op_func_('scale') |
475 | | -_create_op_func_('reshape') |
476 | | -_create_op_func_('transpose') |
477 | | -_create_op_func_('sigmoid_cross_entropy_with_logits') |
478 | | - |
479 | | - |
480 | 321 | def cast(x, dtype, main_program=None): |
481 | 322 | """ |
482 | 323 | This function takes in the input with input_dtype |
|
0 commit comments