[PyTorch] Converting to and using the ONNX model format

        Compared with pth, PyTorch's default format for saved weights, onnx is portable across platforms and runtimes and is convenient to deploy. (It is also said to speed up inference, but I have not verified this.) This article shows how to convert existing pth weights to onnx and how to work with the resulting model.

1. Set up the environment

        Run the following commands in a terminal:

pip install onnx
pip install onnxruntime

        Then import the package in your project:

import onnx

2. Convert pth to onnx

        Use the export function that ships with PyTorch, torch.onnx.export. The code is as follows:

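import torch

def Convert2Onnx(pth_Path, Onnx_Path, model):
    model_loader(model, pth_Path, torch.device('cpu'))
    model.eval()    # export in inference mode so dropout and batch norm behave deterministically
    dummy_input = torch.rand(1, 3, 224, 224)    # adjust to your model's input size, including the batch dimension
    torch.onnx.export(model, dummy_input, Onnx_Path, input_names=['Inp'], output_names=['Outs'])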

        Here, model_loader is a small wrapper that loads pth files adaptively: it still works when some weights do not match the model, because it copies only the entries whose name and shape both match. The code is as follows:

import numpy as np
import torch

def model_loader(model, model_path, device):
    print('     Loading weights from the local file')
    model_dict      = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location = device)
    load_key, no_load_key, temp_dict = [], [], {}
    for k, v in pretrained_dict.items():
        # Keep a weight only if the model has the same key with the same shape
        if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
            temp_dict[k] = v
            load_key.append(k)
        else:
            no_load_key.append(k)
    # Overwrite the matching entries and load the merged state dict
    model_dict.update(temp_dict)
    model.load_state_dict(model_dict)
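
        The loader collects load_key and no_load_key but never reports them, so a silent mismatch is easy to miss. The sketch below isolates the same name-and-shape filter using plain dictionaries of shapes, so that it runs without PyTorch. The helper name partition_by_shape and the example shapes are hypothetical and only illustrate the matching rule.

```python
def partition_by_shape(model_shapes, checkpoint_shapes):
    """Split checkpoint keys into loadable and skipped, by name and shape."""
    load_key, no_load_key = [], []
    for name, shape in checkpoint_shapes.items():
        if name in model_shapes and tuple(model_shapes[name]) == tuple(shape):
            load_key.append(name)
        else:
            no_load_key.append(name)
    return load_key, no_load_key


# Hypothetical shapes: the checkpoint has a 1000-class head, the model a 10-class head.
model_shapes = {
    "conv.weight": (16, 3, 3, 3),
    "fc.weight": (10, 64),
    "fc.bias": (10,),
}
checkpoint_shapes = {
    "conv.weight": (16, 3, 3, 3),
    "fc.weight": (1000, 64),
    "fc.bias": (1000,),
    "aux.weight": (5, 5),
}

loaded, skipped = partition_by_shape(model_shapes, checkpoint_shapes)
print("loaded: ", loaded)
print("skipped:", skipped)
```

        With real state dicts, the shape dictionaries can be built as {k: tuple(v.shape) for k, v in state_dict.items()}. Printing the skipped keys before export makes it obvious when a layer was left at its random initialization.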

3. Load the onnx model

        Loading likewise uses the function that ships with onnx:

model = onnx.load('onnx_model.onnx')

4. Visualize the onnx model

        Open https://netron.app/ in a browser and select your onnx model file there.

5. Other related operations

        1. Getting information about the model

# Check that the model is well-formed
onnx.checker.check_model(model)

# Get the output layer information
output = model.graph.output
print(output)

        2. Editing the model's output layers

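import onnx
from onnx import helper

# Load the model
model = onnx.load('converted_vig.onnx')

# Describe the intermediate tensor to expose: name, data type, shape
# ('layer1' must match the name of a tensor that exists in the graph)
prob_info = helper.make_tensor_value_info('layer1', onnx.TensorProto.FLOAT, [1, 3, 320, 280])

# Insert the new entry into the model's outputs
model.graph.output.insert(0, prob_info)

# Save the new model
onnx.save(model, 'onnx_model_new.onnx')

# To delete an output instead, remove the corresponding entry (item), for example:
item = model.graph.output[0]
model.graph.output.remove(item)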

Origin blog.csdn.net/weixin_37878740/article/details/130617947