模型部署之模型转换

在模型落地应用时,我们往往需要先对模型进行格式转换,本篇博客记录从.pth文件转换为.onnx文件的具体流程。

一、模型结构定义

在模型训练好之后我们有两种保存方式,一种就是将模型的结构与模型的参数一起进行保存,但是由于后续工程的持续改进,这种方法往往并不是很实用,大部分工程采用了第二种方法,也就是只保存模型中的参数部分,由于只有模型的参数,我们在转换模型之前我们需要将模型的结构进行实例化。这也是模型转换的第一步
我们可以在原工程中找到get_model这一接口,但是如果对工程并不熟悉的话这一操作往往会比较困难,所以我们采用了比较简单的方式,也就是将模型的结构文件直接进行实例化。举个例子,有如下模型结构:

import torch
import torch.nn as nn
import torch.nn.functional as F


class cde(nn.Module):

    def __init__(self):
        super(cde, self).__init__()

        self.relu = nn.ReLU(inplace=True)

        number_f = 32
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)

        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        x1 = self.relu(self.e_conv1(x))
        # p1 = self.maxpool(x1)
        x2 = self.relu(self.e_conv2(x1))
        # p2 = self.maxpool(x2)
        x3 = self.relu(self.e_conv3(x2))
        # p3 = self.maxpool(x3)
        x4 = self.relu(self.e_conv4(x3))

        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        # x5 = self.upsample(x5)
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))

        x_r = F.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
        r1, r2, r3, r4, r5, r6, r7, r8 = torch.split(x_r, 3, dim=1)

        x = x + r1 * (torch.pow(x, 2) - x)
        x = x + r2 * (torch.pow(x, 2) - x)
        x = x + r3 * (torch.pow(x, 2) - x)
        enhance_image_1 = x + r4 * (torch.pow(x, 2) - x)
        x = enhance_image_1 + r5 * (torch.pow(enhance_image_1, 2) - enhance_image_1)
        x = x + r6 * (torch.pow(x, 2) - x)
        x = x + r7 * (torch.pow(x, 2) - x)
        enhance_image = x + r8 * (torch.pow(x, 2) - x)
        r = torch.cat([r1, r2, r3, r4, r5, r6, r7, r8], 1)
        return enhance_image_1, enhance_image, r


我们只需要将本结构保存为一个model.py文件,便于后续的实用。接下来我们找到需要转换模型的权重就可以开始转换了
转换代码如下:

import torch
import model

# 实例化模型的结构部分
DCE_net = model.cde()
# 训练之后保存的模型权重
a = "./1.pth"
# 加载模型权重文件
DCE_net.load_state_dict(torch.load(a))

# 定义导出文件的路径以及名称
onnx_path = './2.onnx'
# 定义模型的静态输入,也可以指定动态输入,详细操作见官网
input = torch.randn(1, 3, 640, 640) 
torch.onnx.export(DCE_net, args=input, f=onnx_path, export_params=False, verbose=True, opset_version=11)  # 指定模型的输入,以及onnx的输出路径

运行之后就可以将模型进行转换完毕了!

猜你喜欢

转载自blog.csdn.net/qq_52302919/article/details/124962522