15 | 15 | import os |
16 | 16 | import argparse |
17 | 17 |
18 | | -from onnx import helper |
| 18 | +from onnx import helper, checker |
19 | 19 | import paddle.fluid as fluid |
20 | 20 |
21 | | -import ops |
22 | | -from variables import paddle_variable_to_onnx_tensor |
| 21 | +import fluid_onnx.ops as ops |
| 22 | +from fluid_onnx.variables import paddle_variable_to_onnx_tensor |
| 23 | +from fluid_onnx.variables import PADDLE_TO_ONNX_DTYPE |
23 | 24 |
24 | 25 |
25 | | -def convert(dirname): |
| 26 | +def parse_args(): |
| 27 | + # Parse command-line arguments: paths to the input and output models. |
| 28 | + parser = argparse.ArgumentParser() |
| 29 | + parser.add_argument( |
| 30 | + "--fluid_model", required=True, help="The PaddlePaddle Fluid model directory to convert.") |
| 31 | + parser.add_argument( |
| 32 | + "--onnx_model", required=False, help="The path to save the converted ONNX model.") |
| 33 | + args = parser.parse_args() |
| 34 | + return args |
| 35 | + |
| 36 | + |
| 37 | +def print_arguments(args): |
| 38 | + print('----------- Configuration Arguments -----------') |
| 39 | + for arg, value in sorted(vars(args).items()): |
| 40 | + print('%s: %s' % (arg, value)) |
| 41 | + print('------------------------------------------------') |
| 42 | + |
| 43 | + |
| 44 | +def convert(args): |
26 | 45 | # Read the model files. |
27 | 46 | place = fluid.CPUPlace() |
28 | 47 | exe = fluid.Executor(place) |
29 | 48 |
30 | 49 | inference_scope = fluid.core.Scope() |
31 | 50 | with fluid.scope_guard(inference_scope): |
32 | 51 | [inference_program, feed_target_names, |
33 | | - fetch_targets] = fluid.io.load_inference_model(dirname, exe) |
| 52 | + fetch_targets] = fluid.io.load_inference_model(args.fluid_model, exe) |
34 | 53 |
35 | 54 | # Collect the ONNX nodes built from the program's blocks. |
36 | 55 | onnx_nodes = [] |
37 | | - all_inputs = [] |
38 | | - for block in inference_program.blocks: |
39 | | - all_inputs += [ |
40 | | - paddle_variable_to_onnx_tensor(v, block) for v in block.vars |
41 | | - if v not in ['feed', 'fetch'] |
42 | | - ] |
43 | 56 |
| 57 | + # Load parameters and bake them into the graph as Constant nodes |
| 58 | + global_block = inference_program.global_block() |
| 59 | + for var_name in global_block.vars: |
| 60 | + var = global_block.var(var_name) |
| 61 | + if var_name not in ['feed', 'fetch'] and var.persistable: |
| 62 | + param = fluid.executor.fetch_var(var_name, inference_scope) |
| 63 | + param_node = helper.make_node( |
| 64 | + 'Constant', |
| 65 | + inputs=[], |
| 66 | + outputs=[var_name], |
| 67 | + value=helper.make_tensor( |
| 68 | + name=var_name, |
| 69 | + dims=var.shape, |
| 70 | + data_type=PADDLE_TO_ONNX_DTYPE[var.dtype], |
| 71 | + vals=param.flatten().tolist())) |
| 72 | + onnx_nodes.append(param_node) |
| 73 | + |
| 74 | + # Create inputs |
| 75 | + inputs = [ |
| 76 | + paddle_variable_to_onnx_tensor(v, global_block) |
| 77 | + for v in feed_target_names |
| 78 | + ] |
| 79 | + |
| 80 | + # Create outputs |
| 81 | + fetch_target_names = [ |
| 82 | + fetch_target.name for fetch_target in fetch_targets |
| 83 | + ] |
| 84 | + outputs = [ |
| 85 | + paddle_variable_to_onnx_tensor(v, global_block) |
| 86 | + for v in fetch_target_names |
| 87 | + ] |
| 88 | + |
| 89 | + # Create nodes |
| 90 | + for block in inference_program.blocks: |
44 | 91 | for op in block.ops: |
45 | | - if op.type in ops.PADDLE_TO_ONNX: |
46 | | - # TODO(varunarora): Attributes. |
47 | | - # TODO(varunarora): Use the modifier function to make the |
48 | | - # transformation. |
49 | | - node_proto = helper.make_node( |
50 | | - ops.PADDLE_TO_ONNX[op.type][0], op.input_arg_names, |
51 | | - op.output_arg_names) |
| 92 | + if op.type in ops.node_maker: |
| 93 | + # TODO(kuke): deal with the corner case where vars in |
| 94 | + # different blocks share the same name |
| 95 | + node_proto = ops.node_maker[op.type]( |
| 96 | + inputs=op.input_arg_names, |
| 97 | + attrs=op.attr_names, |
| 98 | + outputs=op.output_arg_names) |
52 | 99 |
53 | 100 | onnx_nodes.append(node_proto) |
54 | 101 | else: |
55 | | - # Not valid to skip any op, so after all edge cases have |
56 | | - # been accounted for, this exception raising to be |
57 | | - # re-enabled. |
58 | | - # raise NameError(op.type) |
59 | | - pass |
| 102 | + if op.type not in ['feed', 'fetch']: |
| 103 | + raise NotImplementedError("OP[%s] is not supported in " |
| 104 | + "the converter!" % op.type) |
| 105 | + |
| 106 | + # Make graph |
| 107 | + model_name = os.path.basename(args.fluid_model.rstrip('/')).split('.')[0] |
| 108 | + onnx_graph = helper.make_graph(onnx_nodes, model_name, inputs, outputs) |
60 | 109 |
61 | | - # Nodes, name of graph, inputs, outputs. |
62 | | - if dirname[-1] == '/': |
63 | | - dirname = dirname[:-1] |
64 | | - graph = helper.make_graph(onnx_nodes, |
65 | | - os.path.basename(dirname).split('.')[0], |
66 | | - all_inputs, []) |
| 110 | + # Make model |
| 111 | + onnx_model = helper.make_model(onnx_graph, producer_name='PaddlePaddle') |
67 | 112 |
68 | | - print graph |
| 113 | + # Model check |
| 114 | + checker.check_model(onnx_model) |
69 | 115 |
70 | | - # TODO(varunarora): Plug in parameters. |
| 116 | + # Print the model in human-readable form |
| 117 | + print("The converted model is:\n{}".format(onnx_model)) |
| 118 | + |
| 119 | + # Save converted model |
| 120 | + if args.onnx_model is not None: |
| 121 | + try: |
| 122 | + with open(args.onnx_model, 'wb') as f: |
| 123 | + f.write(onnx_model.SerializeToString()) |
| 124 | + print("Saved converted model to path: %s" % args.onnx_model) |
| 125 | + except IOError as e: |
| 126 | + print("Failed to save ONNX model to %s: %s" % (args.onnx_model, e)) |
71 | 127 |
72 | 128 |
73 | 129 | if __name__ == "__main__": |
74 | | - # Read arguments: path to model. |
75 | | - parser = argparse.ArgumentParser() |
76 | | - parser.add_argument( |
77 | | - "--modeldir", required=True, help="Input PaddlePaddle model") |
78 | | - args = parser.parse_args() |
79 | | - convert(args.modeldir) |
| 130 | + args = parse_args() |
| 131 | + print_arguments(args) |
| 132 | + convert(args) |
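For reference, an end-to-end conversion might look like the invocation below. This is a minimal sketch: the script name convert.py and the model directory ./fluid_recognize_digits are illustrative assumptions, not part of this change; the flags come straight from the argparse definitions above.

    python convert.py --fluid_model ./fluid_recognize_digits --onnx_model recognize_digits.onnx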
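Since convert() checks the in-memory proto before serializing it, a quick way to confirm the saved file round-trips is to reload it with the same onnx APIs the script already uses. A minimal sketch, assuming the hypothetical output path from the example above:

    import onnx
    from onnx import checker

    # Reload the serialized model and re-run the checker on the bytes
    # that were actually written to disk.
    model = onnx.load("recognize_digits.onnx")  # hypothetical path
    checker.check_model(model)
    print("Graph '%s' has %d nodes." % (model.graph.name, len(model.graph.node)))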