Conversation

YuchiWen

No description provided.

@YuchiWen YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 44658d1 to 6e5ac70 Compare November 22, 2022 04:01
@daquexian
Member

Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024
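
For reference, a minimal standalone sketch of such a test (assumptions: the PR extends the existing `fuse_bn_into_conv` pass, and this free-standing script stands in for the repository's actual unittest harness):

```python
import numpy as np
import onnx
import onnxoptimizer
from onnx import TensorProto, helper, numpy_helper

# ConvTranspose weight layout is (Cin, Cout, kH, kW).
w = numpy_helper.from_array(np.random.rand(3, 4, 2, 2).astype(np.float32), name="W")
scale = numpy_helper.from_array(np.random.rand(4).astype(np.float32), name="scale")
bias = numpy_helper.from_array(np.random.rand(4).astype(np.float32), name="bias")
mean = numpy_helper.from_array(np.random.rand(4).astype(np.float32), name="mean")
var = numpy_helper.from_array(np.random.rand(4).astype(np.float32), name="var")

conv_t = helper.make_node("ConvTranspose", ["X", "W"], ["Y"])
bn = helper.make_node("BatchNormalization", ["Y", "scale", "bias", "mean", "var"], ["Z"])
graph = helper.make_graph(
    [conv_t, bn], "fuse_bn_into_conv_transpose_test",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1, 4, 6, 6])],
    [w, scale, bias, mean, var])
model = helper.make_model(graph)

optimized = onnxoptimizer.optimize(model, ["fuse_bn_into_conv"])
op_types = [n.op_type for n in optimized.graph.node]
# After fusion only the ConvTranspose should remain.
assert op_types == ["ConvTranspose"]
```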

@YuchiWen YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 6e5ac70 to 5d4d388 Compare March 6, 2023 10:52
@YuchiWen
Author

YuchiWen commented Mar 6, 2023

> Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

@daquexian Done, please review.

Signed-off-by: wenyuchi.wyc <wenyuchi.wyc@alibaba-inc.com>
@YuchiWen YuchiWen force-pushed the fuse_bn_into_conv_transpose branch from 5d4d388 to ff1229e Compare March 6, 2023 11:31
@huangzhicong3

Hello, I tried to use this commit to fuse the BN layer and the ConvTranspose layer in my model and found a bug.
The error message is:

```
passes/fuse_bn_into_conv.h:71: modify_conv: Assertion conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C failed.
```

According to the ONNX documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the weight of ConvTranspose has shape (Cin, Cout, K, K), which differs from a regular Conv layer's (Cout, Cin, K, K).
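
To illustrate the layout difference, a minimal numpy sketch (not the pass's actual code; the names are just for illustration): the per-output-channel BN scale must broadcast over weight axis 1 for ConvTranspose, not axis 0 as for Conv.

```python
import numpy as np

# Per-output-channel BN scale, s_c = gamma_c / sqrt(var_c + eps).
Cin, Cout, K = 3, 4, 2
w_conv = np.random.rand(Cout, Cin, K, K).astype(np.float32)   # Conv layout
w_convT = np.random.rand(Cin, Cout, K, K).astype(np.float32)  # ConvTranspose layout
scale = np.random.rand(Cout).astype(np.float32)

fused_conv = w_conv * scale.reshape(Cout, 1, 1, 1)    # scale along axis 0
fused_convT = w_convT * scale.reshape(1, Cout, 1, 1)  # scale along axis 1
```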

@huangzhicong3

huangzhicong3 commented May 30, 2024

Hi, I would like to share my code for fusing ConvTranspose and BN. It has been tested on my own model; I hope it helps others who hit the same issue.
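
For reference, the algebra the script implements is the standard BN folding, applied along the ConvTranspose output-channel axis (weight axis 1). With BN scale γ, bias β, running mean μ, running variance σ², and epsilon ε, the fused weight and bias are

$$W'_{:,c,:,:} = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}}\, W_{:,c,:,:}, \qquad b'_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}}\,(b_c - \mu_c) + \beta_c$$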

```python
import numpy as np
import onnx
import sclblonnx as so

model = onnx.load('../onnx/models/backbone_clean.onnx')
all_initializer = model.graph.initializer
all_node = model.graph.node

# Collect all ConvTranspose and BatchNormalization nodes.
ConvTranspose_list = []
BatchNormalization_list = []
for node in all_node:
    if node.op_type == "ConvTranspose":
        ConvTranspose_list.append(node)
    if node.op_type == "BatchNormalization":
        BatchNormalization_list.append(node)

# Pair each ConvTranspose with the BatchNormalization consuming its output.
valid_ConvTranspose_list = []
for node in ConvTranspose_list:
    for bn_node in BatchNormalization_list:
        if node.output[0] in bn_node.input:
            valid_ConvTranspose_list.append({"conv": node, "bn": bn_node})
            break

# Gather the initializers (weights and BN statistics) used by each pair.
param_dict = {}
for node in valid_ConvTranspose_list:
    param_name = list(node["conv"].input) + list(node["bn"].input)
    for initializer in all_initializer:
        if initializer.name in param_name:
            param_dict[initializer.name] = onnx.numpy_helper.to_array(initializer)

for node in valid_ConvTranspose_list:
    conv = node["conv"]
    bn = node["bn"]
    # Assumes the exporter wrote the BN attributes in the order (epsilon, momentum).
    bn_eps = bn.attribute[0].f
    bn_mom = bn.attribute[1].f  # unused
    bn_w = param_dict[bn.input[1]]      # [Cout]
    bn_b = param_dict[bn.input[2]]      # [Cout]
    bn_mean = param_dict[bn.input[3]]   # [Cout]
    bn_var = param_dict[bn.input[4]]    # [Cout]
    conv_w = param_dict[conv.input[1]]  # [Cin, Cout, H, W]
    if len(conv.input) > 2:
        conv_b = param_dict[conv.input[2]]
    else:
        conv_b = np.zeros_like(bn_b)    # [Cout]

    # ConvTranspose weights are [Cin, Cout, H, W]; move Cout to the front so the
    # per-output-channel BN scale can be applied, then move it back.
    conv_w_tran = conv_w.transpose(1, 0, 2, 3)
    Cout = conv_w_tran.shape[0]
    conv_w_reshape = conv_w_tran.reshape([Cout, -1])
    w_bn = np.diag(bn_w / np.sqrt(bn_eps + bn_var))
    new_conv_w = np.matmul(w_bn, conv_w_reshape).reshape(conv_w_tran.shape).transpose(1, 0, 2, 3)
    bn_b_tmp = bn_b - np.multiply(bn_w, bn_mean) / np.sqrt(bn_eps + bn_var)
    # The bias gets the same per-channel scale matrix as the weights (w_bn, not bn_w).
    new_conv_b = np.matmul(w_bn, conv_b) + bn_b_tmp

    # Assumes the ConvTranspose attributes come in the order
    # (dilations, group, kernel_shape, pads, strides) and that
    # output_padding / output_shape are not set.
    new_node = onnx.helper.make_node(
        name=conv.name + '_bn',
        op_type="ConvTranspose",
        inputs=[conv.input[0], conv.name + '_bn.weights', conv.name + '_bn.bias'],
        outputs=[bn.output[0]],
        dilations=conv.attribute[0].ints,
        group=conv.attribute[1].i,
        kernel_shape=conv.attribute[2].ints,
        pads=conv.attribute[3].ints,
        strides=conv.attribute[4].ints,
    )
    initializer_w = onnx.helper.make_tensor(
        name=conv.name + '_bn.weights',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_w.shape,
        vals=new_conv_w.tobytes(),
        raw=True,
    )
    initializer_b = onnx.helper.make_tensor(
        name=conv.name + '_bn.bias',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_b.shape,
        vals=new_conv_b.tobytes(),
        raw=True,
    )
    model.graph.initializer.append(initializer_w)
    model.graph.initializer.append(initializer_b)

    # Insert the fused node where the original ConvTranspose was, then drop the pair.
    for i, n in enumerate(all_node):
        if conv.name == n.name:
            model.graph.node.insert(i, new_node)
            break
    model.graph.node.remove(conv)
    model.graph.node.remove(bn)

onnx.checker.check_model(model)
onnx.save(model, '../onnx/models/backbone_fuse.onnx')

# Strip the now-unreferenced original initializers with sclblonnx.
graph = so.graph_from_file('../onnx/models/backbone_fuse.onnx')
graph = so.clean(graph)
so.check(graph)
so.graph_to_file(graph, '../onnx/models/backbone_fuse.onnx')
```
