[PyTorch 版] 基于 ResNet50 的分类模型转 ONNX
将基于 ResNet50 的 4 分类 PyTorch 模型(.pth 权重)转换为 ONNX 格式
- 依赖环境
import torch
import torch.onnx
import argparse
from torchvision import datasets, models, transforms
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
- 核心代码
# Core steps of the .pth -> ONNX conversion (excerpt).
# Example paths, identical to the CLI defaults of the complete script; adjust as needed.
checkpoint_path = './model/text_direction.pth'
save_onnx_path = './model/text_direction.onnx'

# ResNet50 backbone with the final fully connected layer replaced by a 4-output head.
model_ft = models.resnet50()
num_fits = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_fits, 4)  # replace the last fully connected layer

# Run on CPU by default.
device = torch.device("cpu:0")
model_ft = model_ft.to(device)

# Load the weights. map_location remaps tensors that were saved on another
# device (e.g. a CUDA GPU) onto the CPU; without it a GPU-saved checkpoint
# fails to load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_ft.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Put layers such as BatchNorm into inference behaviour before tracing/export.
model_ft.eval()

# Print every parameter/buffer name (handy to check the checkpoint matched).
for name in model_ft.state_dict():
    print(name)

# Inference-time preprocessing (ImageNet mean/std normalisation).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(448),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Push a dummy 512x512 image through the same preprocessing so the traced
# input has the shape real inputs will have: (1, 3, 448, 448).
dummy_input = np.ones((512, 512, 3), dtype=np.uint8)
dummy_input = Image.fromarray(cv2.cvtColor(dummy_input, cv2.COLOR_BGR2RGB))
img_tensor = inference_transform(dummy_input)
img_tensor.unsqueeze_(0)  # add the batch dimension

torch.onnx.export(model_ft, img_tensor, save_onnx_path)
- 完整代码
import torch
import torch.onnx
import argparse
from torchvision import datasets, models, transforms
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
# Purpose: convert a ResNet50-based .pth model into an ONNX model.
def model2onnx(checkpoint_path, save_onnx_path, num_classes=4):
    """Convert a ResNet50-based classifier checkpoint (.pth) to ONNX.

    The network is torchvision's ResNet50 whose final fully connected layer
    is replaced by ``nn.Linear(in_features, num_classes)``. The checkpoint
    must therefore be a ``state_dict`` saved from that same architecture.

    Args:
        checkpoint_path: Path of the ``state_dict`` file read with ``torch.load``.
        save_onnx_path: Destination path of the exported ONNX model.
        num_classes: Number of outputs of the replacement final layer.
            Defaults to 4, the value this function previously hard-coded.
    """
    model_ft = models.resnet50()
    num_fits = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_fits, num_classes)  # replace the last fully connected layer

    # Run on CPU by default.
    device = torch.device("cpu:0")
    model_ft = model_ft.to(device)

    # Load the weights. map_location remaps tensors saved on another device
    # (e.g. a CUDA GPU) onto the CPU; without it a GPU-saved checkpoint fails
    # to load on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model_ft.load_state_dict(state_dict)

    # Put layers such as BatchNorm into inference behaviour before export.
    model_ft.eval()

    # Print every parameter/buffer name (handy to check the checkpoint matched).
    for name in model_ft.state_dict():
        print(name)

    # Inference-time preprocessing (ImageNet mean/std normalisation).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    inference_transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(448),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # Push a dummy 512x512 image through the same preprocessing so the traced
    # input has the shape real inputs will have: (1, 3, 448, 448).
    dummy_input = np.ones((512, 512, 3), dtype=np.uint8)
    dummy_input = Image.fromarray(cv2.cvtColor(dummy_input, cv2.COLOR_BGR2RGB))
    img_tensor = inference_transform(dummy_input)
    img_tensor.unsqueeze_(0)  # add the batch dimension

    torch.onnx.export(model_ft, img_tensor, save_onnx_path)
    print('onnx saved to ', save_onnx_path)
if __name__ == '__main__':
    # Command-line entry point. Example invocation:
    #   python3 model2onnx.py --checkpoint=./model/text_direction.pth --save_onnx_path=./model/text_direction.onnx
    cli = argparse.ArgumentParser(description="text_direction model2onnx")

    # (flag, default value, help text) for each string option.
    option_specs = (
        ('--checkpoint', './model/text_direction.pth', 'the path to your checkpoints'),
        ('--save_onnx_path', './model/text_direction.onnx', 'save_onnx_path'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    opts = cli.parse_args()
    model2onnx(opts.checkpoint, opts.save_onnx_path)
- 用法
python3 model2onnx.py --checkpoint=./model/text_direction.pth --save_onnx_path=./model/text_direction.onnx
导出的 ONNX 模型保存到：./model/text_direction.onnx