[PyTorch]Onnx模型格式的转换与应用

        相较于PyTorch默认的模型存储格式pth而言,onnx具有多端通用,方便部署的优点(据称能加快推理速度,但是未验证),本文将介绍如何使用onnx并将原有的pth权重转换为onnx。

一、配置环境

        在控制台中使用如下指令

pip install onnx
pip install onnxruntime

        随后在项目中引入环境

import onnx

二、将pth转换为onnx

        使用onnx自带的export函数即可,代码如下:

def Convert2Onnx(pth_Path,Onnx_Path,model):
    model_loader(model,pth_Path,torch.device('cpu'))
    input = torch.rand(1,3,224,224)    #需要调整为你的模型输入尺寸,包含batch项
    torch.onnx.export(model,input,Onnx_Path,input_names=['Inp'],output_names=['Outs'])

        其中,model_loader是一个封装好的自适应pth加载器,可以在部分权重不匹配的情况下加载pth文件,代码如下:

def model_loader(model,model_path,device):
    print('     开始从本地加载权重文件')
    model_dict      = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location = device)
    load_key, no_load_key, temp_dict = [], [], {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
            temp_dict[k] = v
            load_key.append(k)
        else:
            no_load_key.append(k)
    model_dict.update(temp_dict)
    model.load_state_dict(model_dict)

三、加载onnx模型

        同样适用onnx自带的load代码即可

model = onnx.load('onnx_model.onnx’)

四、onnx模型可视化

        访问网址:https://netron.app/,并在其中选择自己的onnx模型即可。

 五、其他相关操作

        1.模型相关信息的获取

# 检查模型是否完整
onnx.checker.check_model(model)

# 获取输出层信息
output = self.model.graph.output
print(output)

        2.模型的层编辑

import onnx
from onnx import helper

# 加载模型
model = onnx.load('converted_vig.onnx’)

# 创建中间节点:层名、数据类型、维度信息
prob_info = helper.make_tensor_value_info('layer1',onnx.TensorProto.FLOAT, [1, 3, 320, 280])

# 将构造好的中间节点插入到模型中
model.graph.output.insert(0, prob_info)

#保存新模型
onnx.save(model, 'onnx_model_new.onnx’)
#删除的节点item
model.graph.output.remove(item)

猜你喜欢

转载自blog.csdn.net/weixin_37878740/article/details/130617947