# Mainly for testing the problem of converting the model to ONNX after pruning:
# remove the fully connected layers from the VGG16 network, load the pretrained
# model, re-save the model parameters, and use those parameters to convert the
# model to ONNX format.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time :2022/8/4 14:45
# @Author :weiz
# @ProjectName :cbir
# @File :vgg.py
# @Description :
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
class VGG16(nn.Module):
    """VGG16 convolutional backbone with the fully connected head removed.

    The 13 convolutional layers follow the standard VGG16 layout; the three
    FC layers (and the final softmax) are intentionally dropped and replaced
    by a 7x7 average pool, so ``forward`` returns a 512-dim feature vector
    per image instead of class logits.

    NOTE(review): conv1_1 and the first conv of each stage use no padding
    (shapes shrink, e.g. 224 -> 222), and the max pools use padding=(1, 1)
    to compensate — this deviates from torchvision's VGG16, but the
    parameter shapes still line up positionally with the torchvision
    checkpoint (see the __main__ remapping code).
    """

    def __init__(self):
        super(VGG16, self).__init__()
        # Input: 1 * 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(3, 64, 3)                    # conv1_1: 1 * 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1))   # conv1_2: 1 * 64 * 222 * 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool1: 1 * 64 * 112 * 112

        self.conv2_1 = nn.Conv2d(64, 128, 3)                  # conv2_1: 1 * 128 * 110 * 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1)) # conv2_2: 1 * 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool2: 1 * 128 * 56 * 56

        self.conv3_1 = nn.Conv2d(128, 256, 3)                 # conv3_1: 1 * 256 * 54 * 54
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # conv3_2: 1 * 256 * 54 * 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # conv3_3: 1 * 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool3: 1 * 256 * 28 * 28

        self.conv4_1 = nn.Conv2d(256, 512, 3)                 # conv4_1: 1 * 512 * 26 * 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # conv4_2: 1 * 512 * 26 * 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # conv4_3: 1 * 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool4: 1 * 512 * 14 * 14

        self.conv5_1 = nn.Conv2d(512, 512, 3)                 # conv5_1: 1 * 512 * 12 * 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # conv5_2: 1 * 512 * 12 * 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # conv5_3: 1 * 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool5: 1 * 512 * 7 * 7

        # Collapses the 7x7 spatial map to 1x1 -> 1 * 512 * 1 * 1.
        # A fixed AvgPool2d (rather than AdaptiveAvgPool2d) is used because it
        # exports more predictably to ONNX.
        self.feature = torch.nn.AvgPool2d((7, 7), stride=(7, 7), padding=0,
                                          ceil_mode=False, count_include_pad=True)
        # The original VGG16 FC head (fc1/fc2/fc3 + log_softmax) is deliberately
        # removed; this network is used purely as a feature extractor.

    def forward(self, x):
        """Extract a feature vector.

        Args:
            x: input image batch; expected shape (N, 3, 224, 224) so that the
               final map is 7x7 — TODO confirm other sizes are never passed.

        Returns:
            Tensor of shape (N, 512) — pooled conv5 features, flattened.
        """
        out = F.relu(self.conv1_1(x))     # 222
        out = F.relu(self.conv1_2(out))   # 222
        out = self.maxpool1(out)          # 112

        out = F.relu(self.conv2_1(out))   # 110
        out = F.relu(self.conv2_2(out))   # 110
        out = self.maxpool2(out)          # 56

        out = F.relu(self.conv3_1(out))   # 54
        out = F.relu(self.conv3_2(out))   # 54
        out = F.relu(self.conv3_3(out))   # 54
        out = self.maxpool3(out)          # 28

        out = F.relu(self.conv4_1(out))   # 26
        out = F.relu(self.conv4_2(out))   # 26
        out = F.relu(self.conv4_3(out))   # 26
        out = self.maxpool4(out)          # 14

        out = F.relu(self.conv5_1(out))   # 12
        out = F.relu(self.conv5_2(out))   # 12
        out = F.relu(self.conv5_3(out))   # 12
        out = self.maxpool5(out)          # 7

        out = self.feature(out)           # 1 * 512 * 1 * 1
        out = out.view(out.size(0), -1)   # 1 * 512
        return out

    # NOTE: the original code overrode __call__ to call self.forward directly.
    # That bypasses nn.Module.__call__ (forward/backward hooks, etc.); nn.Module
    # already dispatches instance calls to forward, so the override was removed.

    def get_name(self):
        """Return the short model identifier used for saved artifacts."""
        return "vgg16"
def preprocessing(x):
    """Turn a BGR image (as loaded by cv2.imread) into a model input tensor.

    Args:
        x: HxWx3 image array in BGR channel order (OpenCV convention).

    Returns:
        torch.FloatTensor of shape (1, 3, 224, 224), scaled to [0, 1] with
        per-channel means subtracted.
    """
    image = cv2.resize(x, (224, 224))
    # Caffe-style VGG channel means in BGR order, rescaled to [0, 1].
    means = np.array([103.939, 116.779, 123.68]) / 255.
    # HWC -> CHW, pixel values to [0, 1].
    image = np.transpose(image, (2, 0, 1)) / 255.
    image[0] -= means[0]  # subtract B mean
    image[1] -= means[1]  # subtract G mean
    image[2] -= means[2]  # subtract R mean
    image = np.expand_dims(image, axis=0)  # add batch dimension
    # torch.autograd.Variable is deprecated (merged into Tensor in PyTorch 0.4);
    # a plain tensor is the correct modern equivalent.
    inputs = torch.from_numpy(image).float()
    return inputs
def main():
    """Load the saved model, extract a feature vector from a test image, print it."""
    vgg = VGG16()
    # BUG FIX: "./vgg_test.pth" is written via torch.save(vgg) (a whole module,
    # see the __main__ block), so torch.load returns an nn.Module — passing that
    # straight to load_state_dict fails. Normalize to a state dict first.
    checkpoint = torch.load("./vgg_test.pth")
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    vgg.load_state_dict(checkpoint, strict=False)

    image = cv2.imread("test_image/test_1.png")
    image = preprocessing(image)
    feature = vgg(image)
    # Collapse the batch axis and normalize the feature vector to unit sum.
    feature = np.sum(feature.data.cpu().numpy(), axis=0)
    feature /= np.sum(feature)  # normalize
    print(feature)
if __name__ == "__main__":
    # Load the torchvision VGG16 checkpoint (it still contains the FC-layer
    # weights that our trimmed network no longer has).
    pretrained = torch.load("C:\\Users\\weiz\\.cache\\torch\\hub\\checkpoints\\vgg16-397923af.pth")

    vgg = VGG16()
    vgg_dict = vgg.state_dict()

    # Remap checkpoint tensors onto our parameter names positionally: both
    # state dicts list the convolutional weights/biases in the same order, so
    # zip pairs each checkpoint tensor with the matching local key. Extra FC
    # entries at the end of the checkpoint are simply never reached.
    pretrained_dict = {
        local_key: ckpt_tensor
        for (_ckpt_key, ckpt_tensor), (local_key, _local_tensor)
        in zip(pretrained.items(), vgg_dict.items())
    }

    vgg_dict.update(pretrained_dict)
    vgg.load_state_dict(vgg_dict)

    torch.save(vgg.state_dict(), "vgg_dict_test.pth")  # parameters only
    torch.save(vgg, "vgg_test.pth")  # parameters plus network structure
    # main()
# TODO: convert the saved .pth weights to ONNX format.