Merged
127 changes: 90 additions & 37 deletions convert.py
@@ -15,65 +15,118 @@
import os
import argparse

from onnx import helper
from onnx import helper, checker
import paddle.fluid as fluid

import ops
from variables import paddle_variable_to_onnx_tensor
import fluid_onnx.ops as ops
from fluid_onnx.variables import paddle_variable_to_onnx_tensor
from fluid_onnx.variables import PADDLE_TO_ONNX_DTYPE


def convert(dirname):
def parse_args():
# Read arguments: path to model.
parser = argparse.ArgumentParser()
parser.add_argument(
"--fluid_model", required=True, help="Input PaddlePaddle Fluid model.")
parser.add_argument(
"--onnx_model", required=False, help="The path to save ONNX model.")
args = parser.parse_args()
return args


def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')


def convert(args):
# Read the model files.
place = fluid.CPUPlace()
exe = fluid.Executor(place)

inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(dirname, exe)
fetch_targets] = fluid.io.load_inference_model(args.fluid_model, exe)

# Using the blocks in the program, create the ONNX nodes:
onnx_nodes = []
all_inputs = []
for block in inference_program.blocks:
all_inputs += [
paddle_variable_to_onnx_tensor(v, block) for v in block.vars
if v not in ['feed', 'fetch']
]

# Load parameters
global_block = inference_program.global_block()
for var_name in global_block.vars:
var = global_block.var(var_name)
if var_name not in ['feed', 'fetch'] and var.persistable:
param = fluid.executor.fetch_var(var_name, inference_scope)
Collaborator: Oh man, this is great. I spent so much time writing a custom (py)binding to deserialize params manually; I'll throw that stuff away for now. fetch_var is beautiful. You are the expert.

Author: No, actually I am also not familiar with this part; I just found this method by chance :-)

Collaborator: So let me explain why I have been trying to find overly complex solutions to this. If you see this, it clearly seems like scopes get destroyed after their runs, which is fair: we shouldn't keep these thousands, sometimes millions, of variables in memory if they are unused. The only things saved from destruction are the persistable variables in the global block (which you seem to be using here). So vars not in the global block can't be fetched, because everything other than global-block vars is destroyed.

Now, on second thought, we can go through the global block only, because we don't care about inner blocks given our use case. But we should be careful: fetch_var would not have been an option if we had cared about them.

It's also important to note that the suggested way to fetch anything out of a program run is to add it to the fetch_list argument, not to call fetch_var. But to be able to use that, we can't use the default load_inference_model function; we'd need to tweak it to pass in and return the fetch_list we are interested in.

Author (@kuke, Apr 11, 2018): Oh, you have done a much deeper survey than me. It is not easy to discover fetch_var, because ordinary users usually have no need for it.

param_node = helper.make_node(
Collaborator: So where did you figure out that this was the way params need to be set in an ONNX model? I have been confused and asking around.

Author: I also didn't know how to load parameters until I checked the models in onnx/models and found that they use the Constant op to store parameters.

Collaborator: Oh, interesting. I asked this question on the ONNX gitter chat and searched a lot through their issues. So (a) your solution is good based on the info you have; (b) it might be a good idea to follow the practice they are recommending, which is to use the initializers argument on the make_model function. More here: https://github.com/onnx/onnxmltools/tree/master/onnxmltools/convert/common and https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/coreml/NeuralNetwork/fullyconnected.py#L33. So basically, collect the values while traversing the variables, and then pass them in when calling make_model.

Author: This is important information. I think we can keep the current implementation and move on; in the future we can try the initializers approach to see if it works.

Collaborator: I don't think so; parameters are a common concept, so we should make this clear.

'Constant',
inputs=[],
outputs=[var_name],
value=helper.make_tensor(
name=var_name,
dims=var.shape,
data_type=PADDLE_TO_ONNX_DTYPE[var.dtype],
vals=param.flatten().tolist()))
onnx_nodes.append(param_node)
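
A hedged sketch of the initializer-based alternative discussed in the thread above (not part of this PR): the same fetched weights could be registered as graph initializers instead of Constant nodes. The initializer argument on helper.make_graph and the extra value-info inputs are assumptions drawn from the onnxmltools practice the reviewer links to.

    # Sketch only: register parameters as graph initializers rather than
    # emitting one Constant node per parameter.
    initializers = []
    param_inputs = []
    for var_name in global_block.vars:
        var = global_block.var(var_name)
        if var_name not in ['feed', 'fetch'] and var.persistable:
            param = fluid.executor.fetch_var(var_name, inference_scope)
            # The weight values go into a TensorProto...
            initializers.append(
                helper.make_tensor(
                    name=var_name,
                    dims=var.shape,
                    data_type=PADDLE_TO_ONNX_DTYPE[var.dtype],
                    vals=param.flatten().tolist()))
            # ...and each initializer is also declared as a graph input,
            # as ONNX required at the time.
            param_inputs.append(
                paddle_variable_to_onnx_tensor(var_name, global_block))

    # Later, when building the graph:
    # onnx_graph = helper.make_graph(onnx_nodes, model_name,
    #                                inputs + param_inputs, outputs,
    #                                initializer=initializers)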

# Create inputs
inputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in feed_target_names
]

# Create outputs
fetch_target_names = [
fetch_target.name for fetch_target in fetch_targets
]
outputs = [
paddle_variable_to_onnx_tensor(v, global_block)
for v in fetch_target_names
]

# Create nodes
for block in inference_program.blocks:
for op in block.ops:
if op.type in ops.PADDLE_TO_ONNX:
# TODO(varunarora): Attributes.
# TODO(varunarora): Use the modifier function to make the
# transformation.
node_proto = helper.make_node(
ops.PADDLE_TO_ONNX[op.type][0], op.input_arg_names,
op.output_arg_names)
if op.type in ops.node_maker:
# TODO(kuke): deal with the corner case that vars in
Collaborator: Aren't unique names generated, no matter how local the scope is? You probably know more here.

Author: Yes, it seems that duplicate names are allowed for local variables. I am discussing this with others working on framework development to figure out a proper solution. So if you have any ideas, please feel free to speak up.

Collaborator: No ideas here; you can leave the comment as is, but I don't think you should worry about a conflict here.

Author (@kuke, Apr 11, 2018): Yes, I believe we can figure out a way to resolve these conflicts, and we should always be careful about the potential risk until it is solved completely.

# different blocks have the same name
node_proto = ops.node_maker[op.type](
inputs=op.input_arg_names,
attrs=op.attr_names,
outputs=op.output_arg_names)

onnx_nodes.append(node_proto)
else:
# Not valid to skip any op, so once all edge cases have
# been accounted for, this exception raising should be
# re-enabled.
# raise NameError(op.type)
pass
if op.type not in ['feed', 'fetch']:
raise NotImplementedError("OP[%s] is not supported in "
Collaborator: Much smarter error than NameError :)

Author: Yeah, maybe a better one. You reminded me to raise an error here ;-)

"the converter!" % op.type)

# Make graph
model_name = os.path.basename(args.fluid_model.strip('/')).split('.')[0]
onnx_graph = helper.make_graph(onnx_nodes, model_name, inputs, outputs)

# Nodes, name of graph, inputs, outputs.
if dirname[-1] == '/':
dirname = dirname[:-1]
graph = helper.make_graph(onnx_nodes,
os.path.basename(dirname).split('.')[0],
all_inputs, [])
# Make model
onnx_model = helper.make_model(onnx_graph, producer_name='PaddlePaddle')

print graph
# Model check
checker.check_model(onnx_model)

# TODO(varunarora): Plug in parameters.
# Output readable model
print("The converted model is:\n{}".format(onnx_model))

# Save converted model
if args.onnx_model is not None:
try:
with open(args.onnx_model, 'wb') as f:
f.write(onnx_model.SerializeToString())
print("Saved converted model to path: %s" % args.onnx_model)
except IOError:
print("Invalid ONNX model saving path: %s" % args.onnx_model)


if __name__ == "__main__":
# Read arguments: path to model.
parser = argparse.ArgumentParser()
parser.add_argument(
"--modeldir", required=True, help="Input PaddlePaddle model")
args = parser.parse_args()
convert(args.modeldir)
args = parse_args()
print_arguments(args)
convert(args)
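
With the new argument parser, a typical invocation looks like this (the paths are illustrative); print_arguments echoes the configuration before conversion starts:

    python convert.py --fluid_model ./fluid_inference_model/ --onnx_model model.onnx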
13 changes: 13 additions & 0 deletions fluid_onnx/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16 changes: 9 additions & 7 deletions ops.py → fluid_onnx/ops.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from onnx.helper import make_node
"""
Priority of ops (uniques) to figure out support for.

@@ -53,8 +55,8 @@ def abs_op():
pass


def add_op():
pass
def add_op(inputs, attrs, outputs):
return make_node('Add', inputs=inputs, outputs=outputs, broadcast=1)
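
For illustration, here is how a node_maker entry would be invoked; the tensor names below are hypothetical (in convert.py they come from op.input_arg_names and op.output_arg_names), and the explicit broadcast=1 follows pre-opset-7 ONNX Add semantics:

    # Hypothetical names; returns an onnx.NodeProto for Add with broadcast=1.
    node = add_op(inputs=['x', 'y'], attrs=[], outputs=['out'])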


def and_op():
@@ -222,8 +224,8 @@ def lppool_op():
pass


def matmul_op():
pass
def matmul_op(inputs, attrs, outputs):
return make_node('MatMul', inputs=inputs, outputs=outputs)


def max_op():
@@ -445,10 +447,10 @@ def xor_op():
# ONNX Ops that use multiple Paddle ops are keyed by '<op1>,<op2>' fed into the
# modifier.

PADDLE_TO_ONNX = {
node_maker = {
# Paddle op name : (ONNX op name, modifier)
'abs': ('Abs', abs_op),
'elementwise_add': ('Add', add_op),
'elementwise_add': add_op,
Collaborator: So the reason I originally did this weird mapping to tuples is that we may want to reuse such a map in the future for mapping the reverse way, and it would be easier to traverse that. If you think we will probably need to create an entirely new map anyway, this makes much more obvious sense. What do you think?

Author: I thought we could implement the bidirectional conversion in one map. But after writing some code, I feel it would be better to decouple the two conversions very clearly, because it is hard to put the two sets of operators in one-to-one correspondence. For example, ONNX has an FC operator but Fluid doesn't, so in the Fluid → ONNX conversion we would never use the FC op, while in the reverse conversion we would need to implement FC with Fluid's mul and elementwise_add ops.

Collaborator: Completely fair, let's go with this. I thought the same, but yeah, I agree with your conclusion.


# '': 'And', # ?
# 'ArgMax', NEEDS ATTENTION.
@@ -496,7 +498,7 @@ def xor_op():
'': 'MaxRoiPool',
'mean': ('Mean', mean_op),
'': 'Min',
'mul': ('Mul', mul_op),
'mul': matmul_op,
',': 'Neg',
'': 'Not',
'': 'Or',
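
To make the FC example from the thread concrete, here is a hedged sketch of what the reverse (ONNX → Fluid) direction would need: one ONNX FC node expands into two Fluid ops. The helper below is hypothetical and only illustrates the decomposition:

    import paddle.fluid as fluid

    def fc_from_onnx(x, weight, bias):
        # Hypothetical sketch: Fluid has no single FC op, so an ONNX FC
        # node maps to mul followed by elementwise_add.
        projected = fluid.layers.mul(x=x, y=weight)
        return fluid.layers.elementwise_add(x=projected, y=bias)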
16 changes: 13 additions & 3 deletions variables.py → fluid_onnx/variables.py
@@ -12,16 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from onnx import helper, onnx_pb2, TensorProto
import paddle.fluid.core as core


def paddle_variable_to_onnx_tensor(paddle_var_name, block):
# TODO(varunarora): Need to do this only in the case of VarType.LOD_TENSOR.
paddle_var = block.var(paddle_var_name)
return helper.make_tensor_value_info(paddle_var_name,
PADDLE_TO_ONNX_DTYPE[paddle_var.dtype],
paddle_var.shape)
shape = paddle_onnx_shape(paddle_var.shape)
return helper.make_tensor_value_info(
paddle_var_name, PADDLE_TO_ONNX_DTYPE[paddle_var.dtype], shape)


def paddle_onnx_shape(paddle_shape):
""" Convert shape info from paddle to onnx
"""

onnx_shape = np.array(list(paddle_shape))
onnx_shape[onnx_shape < 0] = 0
return tuple(onnx_shape)
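
A quick check of the intended behavior (Paddle marks an unknown dimension, e.g. the batch size, with -1; the converter maps it to 0 in the ONNX shape):

    assert paddle_onnx_shape((-1, 3, 224, 224)) == (0, 3, 224, 224)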


PADDLE_TO_ONNX_DTYPE = {