探索发现:tensorflow转onnx时,输入无符号shape的情况解决。

一、前言

如标题,有几次朋友遇到这种情况,所以我想看看能不能直接更改 onnx 模型的 input shape 来解决这种问题。这种情况目前全发生在 tensorflow -> onnx 过程中,pytorch 由于有 onnx 的导出官方 api, 所以没有此烦心事。

二、代码与讲解

方法有两种,具体如下:

1、简化前修改输入 shape

先看看未修改前输入的无符号shape:
在这里插入图片描述
在这里插入图片描述
很明显,本该是1的地方出现了几个乱七八糟的符号。但是怎么做呢?我主要参考了这个

from onnx.tools import update_model_dims
import onnx 


def renew_shape():
    onnx_file = "xxxx.onnx"  # onnx path
    model = onnx.load(onnx_file)
    print(model.graph.input[0].name)  # 可用这句代码预先查看输入节点名
    print(model.graph.output[0].name)  # 同上
    
    variable_length_model = update_model_dims.update_inputs_outputs_dims(model, {
    
    'input_name': [1, 64, 64, 3]}, {
    
    'output_name': [1, 3]})  # 列表中的 shape 改成自己对应的
    print(variable_length_model.graph.input[0])  # 查看更改后的输入输出 shape
    print(variable_length_model.graph.output[0])
    onnx.save(variable_length_model, "xxxx.onnx")  # 保存更改后的模型

修改后查看网络结构:
在这里插入图片描述
在这里插入图片描述
好了,修改成功,其他网络结构不变,下步就是简化模型了。

2、简化时指定输入shape

之所以要修改输入 shape ,其实主要是简化时会报错,当然如果在简化时直接指明 shape 的话,也没什么问题,如下。

import onnx
from onnxsim import simplify


onnx_file = "xxxx.onnx"
sim_onnx_path = "xxxx_sim.onnx"
model = onnx.load(onnx_file)  # load onnx model
onnx.checker.check_model(model)  # check onnx model
#print(onnx.helper.printable_graph(model.graph))  # print a human readable representation of the graph
#print('Export complete. ONNX model saved to %s\nView with https://github.com/lutzroeder/netron' % export_onnx_file)

model_simp, check = simplify(model, input_shapes={
    
    "input_name":[1, 64, 64, 3]})  # 列表里的 shape 改成自己对应的
onnx.save(model_simp, sim_onnx_path)
print("Simplify onnx done !")

三、后语

其实每当出现上述问题,难搞的还在后面,上面说的还只是小问题,毕竟tfboy不好当。仓促写成,如有遗漏、错处,还请指出,谢谢。

猜你喜欢

转载自blog.csdn.net/tangshopping/article/details/111874321