# Three ways to convert a PyTorch .pth checkpoint to ONNX (pth转onnx的三种情况)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time        :2022/8/3 16:19
# @Author      :weiz
# @ProjectName :cbir
# @File        :pth2onnx.py
# @Description :
from vgg import *
import torch


def pth2onnx_all(input_shape, model_path, onnx_path, opset_version=9):
    """
    Export a checkpoint that stores the whole model (architecture + weights) to ONNX.

    The checkpoint is expected to contain a pickled model object rather than a
    state_dict, so no separate network definition has to be passed in.

    :param input_shape: shape of the random dummy input used to trace the model,
        e.g. ``[1, 3, 224, 224]``; any rank the model accepts is allowed.
    :param model_path: path of the full-model checkpoint.
    :param onnx_path: destination path of the ``.onnx`` file.
    :param opset_version: ONNX opset to target (default 9, unchanged from before).
    """
    import inspect

    # Export is done on the CPU; the checkpoint tensors are mapped there too.
    device = torch.device("cpu")

    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
    # unpickle a full nn.Module. Opt out explicitly when the keyword exists (older
    # releases do not accept it). WARNING: this unpickles arbitrary objects and can
    # execute code, so only load checkpoints from a trusted source.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    # Inference mode (affects layers such as Dropout/BatchNorm); consistent with
    # the state_dict-based export helpers.
    model.eval()

    # Random tensor used only to trace the graph; its values do not matter.
    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model, dummy_input, onnx_path, opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"])


def pth2onnx(model_net, input_shape, model_path, onnx_path, opset_version=9):
    """
    Load a state_dict checkpoint into ``model_net`` and export it to ONNX.

    The checkpoint holds only the parameters, so the caller must supply an
    instance of the matching network architecture. Saving state_dicts is the
    approach PyTorch officially recommends.

    :param model_net: network instance whose architecture matches the checkpoint.
        It is modified in place: weights are loaded and it is put in eval mode.
    :param input_shape: shape of the random dummy input used to trace the model,
        e.g. ``[1, 3, 224, 224]``; any rank the model accepts is allowed.
    :param model_path: path of the state_dict checkpoint.
    :param onnx_path: destination path of the ``.onnx`` file.
    :param opset_version: ONNX opset to target (default 9, unchanged from before).
    """
    # Export is done on the CPU; the checkpoint tensors are mapped there too.
    # NOTE: the dummy input is created on the CPU, so model_net is assumed to
    # live on the CPU as well.
    device = torch.device("cpu")

    state_dict = torch.load(model_path, map_location=device)
    model_net.load_state_dict(state_dict)
    # Switch to inference mode (affects layers such as Dropout/BatchNorm) before export.
    model_net.eval()

    # Random tensor used only to trace the graph; its values do not matter.
    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model_net, dummy_input, onnx_path, opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"])


def pth2onnx_dynamic(model_net, input_shape, model_path, onnx_path, dynamic_axes=None,
                     opset_version=9):
    """
    Load a state_dict checkpoint into ``model_net`` and export it to ONNX with
    some dimensions marked as dynamic (not fixed to the dummy input's sizes).

    :param model_net: network instance whose architecture matches the checkpoint.
        It is modified in place: weights are loaded and it is put in eval mode.
    :param input_shape: shape of the random dummy input used to trace the model,
        e.g. ``[1, 3, 224, 224]``; any rank the model accepts is allowed.
    :param model_path: path of the state_dict checkpoint.
    :param onnx_path: destination path of the ``.onnx`` file.
    :param dynamic_axes: mapping forwarded to ``torch.onnx.export(dynamic_axes=...)``.
        When None, only the two spatial axes (2 and 3) of "input" are dynamic;
        the output shape stays as traced.
    :param opset_version: ONNX opset to target (default 9, unchanged from before).
    """
    # Export is done on the CPU; the checkpoint tensors are mapped there too.
    # NOTE: the dummy input is created on the CPU, so model_net is assumed to
    # live on the CPU as well.
    device = torch.device("cpu")

    state_dict = torch.load(model_path, map_location=device)
    model_net.load_state_dict(state_dict)
    # Switch to inference mode (affects layers such as Dropout/BatchNorm) before export.
    model_net.eval()

    # Random tensor used only to trace the graph; its values do not matter.
    dummy_input = torch.randn(*input_shape, device=device)

    if dynamic_axes is None:
        # Labels assume an NCHW input (as in the [1, 3, 224, 224] example):
        # axis 2 is height, axis 3 is width. To also make the batch size dynamic
        # (and, for a 4-D output, its spatial dims), pass e.g.:
        #   {"input": {0: "batch_size", 2: "in_height", 3: "in_width"},
        #    "output": {0: "batch_size", 2: "out_height", 3: "out_width"}}
        dynamic_axes = {"input": {2: "in_height", 3: "in_width"}}

    torch.onnx.export(model_net, dummy_input, onnx_path, opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"], dynamic_axes=dynamic_axes)


if __name__ == "__main__":
    # Demo run: export a VGG16 three ways. The first two exports rebuild the
    # network from VGG16() and load a state_dict checkpoint into it; the last
    # loads a checkpoint that stores the whole model.
    dummy_shape = [1, 3, 224, 224]
    dict_ckpt = "./vgg_dict_test.pth"

    net = VGG16()
    pth2onnx(net, dummy_shape, dict_ckpt, "./vgg_no_full.onnx")
    pth2onnx_dynamic(net, dummy_shape, dict_ckpt, "./vgg_no_full_dynamic.onnx")

    # NOTE(review): "no_ful" in this output name looks like it may be a typo for a
    # full-model export name — confirm before renaming, other tools may rely on it.
    pth2onnx_all(dummy_shape, "./vgg_test.pth", "./vgg16_no_ful.onnx")

# 猜你喜欢 ("You may also like") — leftover from the source blog page

# Reposted from (转载自) blog.csdn.net/qq_31112205/article/details/126180272