Cut off the nodes that follow the last Conv layers of an ONNX model, name each new output after its Conv node, and give each new output the shape of that node's output.

In some models, the operators after the last convolutional layer are not well suited to running in the inference engine. Cutting off the operators after the Conv layers can give better performance on the CPU.
This example covers:
1. Getting the shape of an intermediate ONNX node.
2. Adding an ONNX model output and setting its name, type, and shape.
3. Editing an ONNX model.


Example: cutting off the part highlighted in green in the original figure (the figure is not reproduced here).
import onnx
import sys
import json
from onnx import shape_inference, TensorProto

# Command-line entry: expects the path of the ONNX model to truncate.
if len(sys.argv) < 2:
    print('Usage: ' + sys.argv[0] + ' <onnx_filename>')
    sys.exit(-1)

onnx_file = sys.argv[1]

# Load the ONNX model. `graph` is read as a module-level global further down.
model = onnx.load(onnx_file)

graph = model.graph

# The script only handles models with exactly three outputs named
# output0/output1/output2; anything else is rejected with a failure status.
outputs = model.graph.output
if len(outputs) != 3:
    print("This isn't ScoreBoxKpt model!")
    sys.exit(1)

output_list = ["output0", "output1", "output2"]

for output in outputs:
    if output.name in output_list:
        print(f"output name: {output.name}")
    else:
        print("This isn't a fit model!")
        sys.exit(1)

def getConvList(endName, onnx_graph=None):
    """Return the names of the nearest ``Conv`` nodes upstream of a tensor.

    Walks the graph backwards, breadth-first, starting from the tensor
    ``endName``. A non-Conv producer is traversed through its inputs. A Conv
    producer is recorded and the walk does not continue past it. Names are
    returned in discovery order, without duplicates.

    Args:
        endName: name of the tensor to start from (e.g. a graph output).
        onnx_graph: object with a ``node`` sequence to search; defaults to
            the module-level ``graph``.

    Returns:
        list of Conv node names.
    """
    from collections import deque

    if onnx_graph is None:
        onnx_graph = graph

    # Map each tensor name to its producing nodes once, instead of scanning
    # every node for every visited tensor. An empty name marks an omitted
    # optional output and must never be treated as a real tensor.
    producers = {}
    for node in onnx_graph.node:
        for out_name in node.output:
            if out_name:
                producers.setdefault(out_name, []).append(node)

    queue = deque([endName])
    seen = {endName}
    convList = []
    while queue:
        name = queue.popleft()
        for node in producers.get(name, ()):
            if node.op_type == "Conv":
                if node.name not in convList:
                    convList.append(node.name)
            else:
                for in_name in node.input:
                    # Skip omitted optional inputs ("") and already-queued tensors.
                    if in_name and in_name not in seen:
                        seen.add(in_name)
                        queue.append(in_name)
    return convList

# For each of the three expected outputs, find the nearest Conv nodes feeding it.
Conv0, Conv1, Conv2 = (getConvList(out_name) for out_name in output_list)

def save2json(save_dict, name):
    """Write ``save_dict`` as 4-space-indented JSON to the file ``name``.

    If ``save_dict`` is empty, prints a notice and writes nothing.
    Always returns None.
    """
    if not len(save_dict):
        print("this is nothing to save json")
        return None
    with open(name, 'w') as out_file:
        json.dump(
            save_dict,
            out_file,
            sort_keys=False,
            indent=4,
            separators=(',', ': '),
        )

# Conv nodes feeding each original output; written to conv_param.json at the end.
save_dict = {output_list[0]: Conv0,
             output_list[1]: Conv1,
             output_list[2]: Conv2,
             }

# Every selected Conv node once, in first-seen order. A Conv may feed
# several outputs, and duplicates would only repeat the pruning work below.
conv_list = list(dict.fromkeys(Conv0 + Conv1 + Conv2))

# Get the inferred type (element type and shape) of each selected Conv node's
# first output, keyed by node name.
output_dim_dic = {}
inferred_onnx_model = shape_inference.infer_shapes(model)
inferred_graph = inferred_onnx_model.graph
inferred_value_info = inferred_graph.value_info
# A tensor that is already a graph output is described in graph.output, not
# in value_info, so index both. value_info wins on a name clash.
tensor_types = {}
for value_info in list(inferred_graph.output) + list(inferred_value_info):
    tensor_types[value_info.name] = value_info.type.tensor_type
for node in graph.node:
    if node.name in conv_list and node.output[0] in tensor_types:
        output_dim_dic[node.name] = tensor_types[node.output[0]]

# Delete every node downstream of each selected Conv node.
for name in conv_list:
    target_node = next((node for node in graph.node if node.name == name), None)
    if target_node is None:
        # Already deleted as a descendant of an earlier Conv in conv_list.
        print(f"skip node {name}: already removed")
        continue

    # Tensors produced downstream of this Conv; grows as consumers are found.
    # Empty names denote omitted optional inputs/outputs and never link nodes.
    reachable = {t for t in target_node.output if t}
    doomed = set()  # indices into graph.node; stable until the deletion below
    changed = True
    while changed:
        changed = False
        for idx, node in enumerate(graph.node):
            if idx in doomed or not reachable.intersection(node.input):
                continue
            reachable.update(t for t in node.output if t)
            doomed.add(idx)
            changed = True

    for idx in sorted(doomed):
        print(f"1remove node {graph.node[idx].name}")
    # Delete back-to-front so the remaining indices stay valid. Removing while
    # iterating forward would skip the element after each removed one.
    for idx in sorted(doomed, reverse=True):
        del graph.node[idx]

# Repeatedly drop nodes (other than the selected Convs) whose outputs are
# consumed by no remaining node, until a pass removes nothing.
while True:
    # Empty names are omitted optional inputs, not real consumers.
    consumed = {t for node in graph.node for t in node.input if t}
    dead = [idx for idx, node in enumerate(graph.node)
            if node.name not in conv_list
            and not consumed.intersection(node.output)]
    if not dead:
        break
    for idx in dead:
        print(f"2remove node {graph.node[idx].name}")
    # Delete back-to-front so the collected indices stay valid.
    for idx in reversed(dead):
        del graph.node[idx]

# Expose each remaining selected Conv's first output as a graph output. The
# output tensor is renamed to the node name and given the inferred type.
save_output_name = []
for node in graph.node:
    if node.name not in conv_list:
        continue
    tensor_type = output_dim_dic.get(node.name)
    if tensor_type is None:
        sys.exit(f"No inferred type for the output of Conv node {node.name!r}")
    # Add output layer
    node.output[0] = node.name
    output_info = onnx.ValueInfoProto()
    output_info.name = node.output[0]
    output_info.type.tensor_type.shape.dim.extend(tensor_type.shape.dim)
    # Keep the inferred element type; fall back to FLOAT if it is undefined (0).
    output_info.type.tensor_type.elem_type = tensor_type.elem_type or TensorProto.FLOAT
    print(output_info)
    graph.output.extend([output_info])
    save_output_name.append(node.output[0])

# Keep only the Conv outputs added above. Collect indices first and delete
# back-to-front. Removing while iterating skips elements, so a single in-place
# pass is not guaranteed to remove every stale output.
stale_outputs = [idx for idx, output in enumerate(model.graph.output)
                 if output.name not in save_output_name]
for idx in reversed(stale_outputs):
    del model.graph.output[idx]
outputs = model.graph.output
# Validate and save the truncated model, then record which Conv nodes feed
# each original output.
onnx.checker.check_model(model)
onnx.save(model, "backbone.onnx")
save2json(save_dict, 'conv_param.json')

Guess you like

Source: blog.csdn.net/soralaro/article/details/132621291