- Notifications
You must be signed in to change notification settings - Fork 101
Support fuse bn into ConvTranspose. #106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
44658d1
to 6e5ac70
Compare Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024 |
6e5ac70
to 5d4d388
Compare
@daquexian Done, please review. |
Signed-off-by: wenyuchi.wyc <wenyuchi.wyc@alibaba-inc.com>
5d4d388
to ff1229e
Compare Hello, I tried to use this commit to fuse the BN layer and the ConvTranspose layer in my model and found a bug: according to the ONNX documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the weight tensor of ConvTranspose has shape (Cin, Cout, K, K), which differs from the weight layout of a normal Conv layer (Cout, Cin, K, K). |
"""Fuse BatchNormalization nodes into the preceding ConvTranspose nodes.

Hi, I would like to share my code for fusing ConvTranspose and BN. It has
been tested on my own model. I hope it will help others who have the same
issue.

Run as a script, it loads ``INPUT_MODEL_PATH``, folds every eligible
ConvTranspose -> BatchNormalization pair into a single ConvTranspose node,
and writes the result to ``OUTPUT_MODEL_PATH``.
"""

from collections import Counter

import numpy as np

INPUT_MODEL_PATH = '../onnx/models/backbone_clean.onnx'
OUTPUT_MODEL_PATH = '../onnx/models/backbone_fuse.onnx'

# Default of the BatchNormalization ``epsilon`` attribute when it is omitted.
DEFAULT_BN_EPSILON = 1e-5


def fold_bn_into_conv_transpose(conv_w, conv_b, bn_w, bn_b, bn_mean, bn_var,
                                bn_eps=DEFAULT_BN_EPSILON, group=1):
    """Return ConvTranspose weights and bias with a BatchNorm folded in.

    Args:
        conv_w: ConvTranspose weight of shape ``(Cin, Cout / group, *kernel)``.
            Any number of spatial kernel dimensions is accepted.
        conv_b: ConvTranspose bias of shape ``(Cout,)``, or ``None`` when the
            node has no bias input.
        bn_w: BatchNorm scale (gamma), shape ``(Cout,)``.
        bn_b: BatchNorm shift (beta), shape ``(Cout,)``.
        bn_mean: BatchNorm running mean, shape ``(Cout,)``.
        bn_var: BatchNorm running variance, shape ``(Cout,)``.
        bn_eps: BatchNorm epsilon added to the variance.
        group: ConvTranspose ``group`` attribute.

    Returns:
        ``(new_w, new_b)`` with the shape of ``conv_w`` and ``(Cout,)``
        respectively, both in the dtype of ``conv_w``.

    Raises:
        ValueError: if the shapes are inconsistent with ``group``.
    """
    conv_w = np.asarray(conv_w)
    bn_w = np.asarray(bn_w)
    bn_b = np.asarray(bn_b)
    bn_mean = np.asarray(bn_mean)
    bn_var = np.asarray(bn_var)

    if conv_w.ndim < 2:
        raise ValueError("conv_w must have at least 2 dimensions")
    c_in, c_out_per_group = conv_w.shape[:2]
    kernel_dims = conv_w.shape[2:]
    if group < 1 or c_in % group != 0:
        raise ValueError(
            f"input channels ({c_in}) not divisible by group ({group})")
    c_out = c_out_per_group * group
    if bn_w.shape != (c_out,):
        raise ValueError(
            f"BatchNorm has {bn_w.shape} channels, expected ({c_out},)")

    # Per-output-channel multiplier applied by the BatchNorm.
    scale = bn_w / np.sqrt(bn_var + bn_eps)

    # Output channel ``g * c_out_per_group + j`` is produced by the weights
    # ``conv_w[g * c_in / group : (g + 1) * c_in / group, j]``, so expose the
    # group axis explicitly and broadcast the matching slice of ``scale``.
    grouped_w = conv_w.reshape(
        (group, c_in // group, c_out_per_group) + kernel_dims)
    grouped_scale = scale.reshape(
        (group, 1, c_out_per_group) + (1,) * len(kernel_dims))
    new_w = (grouped_w * grouped_scale).reshape(conv_w.shape)

    if conv_b is None:
        conv_b = np.zeros_like(bn_b)
    # BN(y + b) = (y + b - mean) * scale + beta, element-wise per channel.
    new_b = (np.asarray(conv_b) - bn_mean) * scale + bn_b

    return new_w.astype(conv_w.dtype), new_b.astype(conv_w.dtype)


def fuse_conv_transpose_bn(model):
    """Fuse ConvTranspose -> BatchNormalization pairs of ``model`` in place.

    A pair is fused only when the BatchNormalization's data input is the
    ConvTranspose output, that output has no other consumer and is not a
    graph output, the BatchNormalization is not in training mode, and all
    parameters involved are graph initializers.

    Args:
        model: an ``onnx.ModelProto``; its graph is modified in place.

    Returns:
        The number of pairs that were fused.
    """
    # Imported lazily so the numeric helper above works without onnx.
    from onnx import helper, numpy_helper

    graph = model.graph
    initializers = {init.name: init for init in graph.initializer}
    consumer_count = Counter(
        name for node in graph.node for name in node.input)
    graph_outputs = {out.name for out in graph.output}
    conv_by_output = {
        node.output[0]: node
        for node in graph.node
        if node.op_type == "ConvTranspose"
    }

    pairs = []
    for bn in graph.node:
        if bn.op_type != "BatchNormalization" or len(bn.input) < 5:
            continue
        conv = conv_by_output.get(bn.input[0])
        if conv is None:
            continue
        conv_out = conv.output[0]
        # Removing the conv would break any other consumer of its output.
        if consumer_count[conv_out] != 1 or conv_out in graph_outputs:
            continue
        # Extra (training-only) BatchNorm outputs cannot be reproduced.
        if any(bn.output[1:]):
            continue
        bn_attrs = {attr.name: attr for attr in bn.attribute}
        if "training_mode" in bn_attrs and bn_attrs["training_mode"].i:
            continue
        has_bias = len(conv.input) > 2 and bool(conv.input[2])
        param_names = [conv.input[1]] + list(bn.input[1:5])
        if has_bias:
            param_names.append(conv.input[2])
        if not all(name in initializers for name in param_names):
            continue
        pairs.append((conv, bn, bn_attrs, has_bias))

    for conv, bn, bn_attrs, has_bias in pairs:
        conv_attrs = {attr.name: attr for attr in conv.attribute}
        group = conv_attrs["group"].i if "group" in conv_attrs else 1
        bn_eps = (bn_attrs["epsilon"].f if "epsilon" in bn_attrs
                  else DEFAULT_BN_EPSILON)

        def load(name):
            return numpy_helper.to_array(initializers[name])

        new_conv_w, new_conv_b = fold_bn_into_conv_transpose(
            load(conv.input[1]),
            load(conv.input[2]) if has_bias else None,
            load(bn.input[1]),
            load(bn.input[2]),
            load(bn.input[3]),
            load(bn.input[4]),
            bn_eps=bn_eps,
            group=group,
        )

        # Unnamed nodes would otherwise all map to the same "_bn" names.
        base_name = (conv.name or conv.output[0]) + '_bn'
        weight_name = base_name + '.weights'
        bias_name = base_name + '.bias'

        new_node = helper.make_node(
            "ConvTranspose",
            inputs=[conv.input[0], weight_name, bias_name],
            outputs=[bn.output[0]],
            name=base_name,
        )
        # Copy every original attribute (including auto_pad, output_padding
        # and output_shape) instead of relying on their position.
        new_node.attribute.extend(conv.attribute)

        # from_array records the tensor dtype that matches the array.
        graph.initializer.append(
            numpy_helper.from_array(new_conv_w, name=weight_name))
        graph.initializer.append(
            numpy_helper.from_array(new_conv_b, name=bias_name))

        # Insert at the conv's position to keep the graph topologically
        # sorted, then drop the two nodes that were replaced.
        graph.node.insert(list(graph.node).index(conv), new_node)
        graph.node.remove(conv)
        graph.node.remove(bn)

    return len(pairs)


def main():
    """Load the input model, fuse it, validate it and save the result."""
    import onnx
    import sclblonnx as so

    model = onnx.load(INPUT_MODEL_PATH)
    fuse_conv_transpose_bn(model)
    onnx.checker.check_model(model)
    onnx.save(model, OUTPUT_MODEL_PATH)

    # Second pass with sclblonnx, which also cleans up the saved graph
    # (the initializers of the removed nodes are left behind above).
    graph = so.graph_from_file(OUTPUT_MODEL_PATH)
    graph = so.clean(graph)
    so.check(graph)
    so.graph_to_file(graph, OUTPUT_MODEL_PATH)


if __name__ == "__main__":
    main()
No description provided.