Exploration: fixing unspecified (dynamic) input shapes when converting TensorFlow to ONNX

I. Introduction

As the title says, several friends have run into this situation, so I wanted to see whether it could be solved by directly changing the input shape of the ONNX model. So far this happens in the TensorFlow -> ONNX process. PyTorch has an official ONNX export API, so it does not have this trouble.

II. Code and explanation

There are two methods:

1. Modify the input shape before simplifying

First, look at the input with an unspecified shape before modification. Viewed in Netron, several dimensions that should be 1 show messy symbolic names instead. So how do we fix that? I mainly referred to this.

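from onnx.tools import update_model_dims
import onnx


def renew_shape():
    onnx_file = "xxxx.onnx"  # path to the ONNX model
    model = onnx.load(onnx_file)
    print(model.graph.input[0].name)  # print the input node name first
    print(model.graph.output[0].name)  # likewise for the output node name

    # Use the real input/output names printed above (one entry per graph
    # input/output) and replace the shapes with the ones for your own model
    variable_length_model = update_model_dims.update_inputs_outputs_dims(
        model,
        {"input_name": [1, 64, 64, 3]},
        {"output_name": [1, 3]},
    )
    print(variable_length_model.graph.input[0])  # inspect the updated input shape
    print(variable_length_model.graph.output[0])  # inspect the updated output shape
    onnx.save(variable_length_model, "xxxx.onnx")  # save the modified model (overwrites the original)


if __name__ == "__main__":
    renew_shape()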

Checking the network structure after modification shows that the change succeeded and the rest of the network is unchanged. The next step is to simplify the model.

2. Specify the input shape when simplifying

The reason for modifying the input shape is that simplification otherwise reports an error. Of course, if the shape is specified directly during simplification, there is no problem, as follows:

import onnx
from onnxsim import simplify


onnx_file = "xxxx.onnx"
sim_onnx_path = "xxxx_sim.onnx"
model = onnx.load(onnx_file)  # load the ONNX model
onnx.checker.check_model(model)  # validate the ONNX model
# print(onnx.helper.printable_graph(model.graph))  # print a human-readable representation of the graph
# print('ONNX model loaded from %s\nView with https://github.com/lutzroeder/netron' % onnx_file)

# Replace the shape in the list with the one for your own model
model_simp, check = simplify(model, input_shapes={"input_name": [1, 64, 64, 3]})
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, sim_onnx_path)
print("Simplify onnx done !")

III. Postscript

In fact, whenever these problems show up, the hard part is still ahead; they are only minor issues. After all, being a tfboy is not easy. This was written in a hurry, so if there are any omissions or mistakes, please point them out. Thank you.

Source: blog.csdn.net/tangshopping/article/details/111874321