# VGG16: prune the network and load pretrained model parameters.
#
# Mainly a test of converting a pruned model to ONNX: the fully-connected
# layers of VGG16 are removed, pretrained weights are loaded, and the model
# parameters are re-saved so they can be used for ONNX export.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time        :2022/8/4 14:45
# @Author      :weiz
# @ProjectName :cbir
# @File        :vgg.py
# @Description :
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2


class VGG16(nn.Module):
    """VGG16 backbone with the fully-connected head removed.

    The stock VGG16 classifier (fc1/fc2/fc3 + softmax) is pruned away and
    replaced by a 7x7 average pool, so the network returns a flat 512-d
    feature vector per image (one row per batch element).

    Layout notes (for a 1 x 3 x 224 x 224 input; per-layer sizes are
    annotated on each line):
    - The *first* conv of each stage has no padding (shrinks by 2), while
      the pooling layers use padding=(1, 1); the spatial size still ends
      at 7 x 7 before the average pool. This differs from the standard
      torchvision VGG16 (all convs padded, pools unpadded), but the
      parameter *shapes* match, which is what the positional weight copy
      in the export script relies on.
    """

    def __init__(self):
        super(VGG16, self).__init__()

        # input: 1 * 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(3, 64, 3)  # conv1_1: 1 * 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1))  # conv1_2: 1 * 64 * 222 * 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool1: 1 * 64 * 112 * 112

        self.conv2_1 = nn.Conv2d(64, 128, 3)  # conv2_1: 1 * 128 * 110 * 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1))  # conv2_2: 1 * 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool2: 1 * 128 * 56 * 56

        self.conv3_1 = nn.Conv2d(128, 256, 3)  # conv3_1: 1 * 256 * 54 * 54
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # conv3_2: 1 * 256 * 54 * 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # conv3_3: 1 * 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool3: 1 * 256 * 28 * 28

        self.conv4_1 = nn.Conv2d(256, 512, 3)  # conv4_1: 1 * 512 * 26 * 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # conv4_2: 1 * 512 * 26 * 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # conv4_3: 1 * 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool4: 1 * 512 * 14 * 14

        self.conv5_1 = nn.Conv2d(512, 512, 3)  # conv5_1: 1 * 512 * 12 * 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # conv5_2: 1 * 512 * 12 * 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # conv5_3: 1 * 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1))  # maxpool5: 1 * 512 * 7 * 7

        # Collapse the 7x7 map to 1x1 -> 1 * 512 * 1 * 1.
        # A fixed AvgPool2d (not AdaptiveAvgPool2d) is used; the input
        # must therefore be 224 x 224 so the map is exactly 7 x 7 here.
        self.feature = torch.nn.AvgPool2d((7, 7), stride=(7, 7), padding=0, ceil_mode=False, count_include_pad=True)

        # The original VGG16 classifier head (fc 25088->4096->4096->1000
        # + log-softmax) was deliberately removed for feature extraction.

    def forward(self, x):
        """Return a (batch, 512) feature tensor for a (batch, 3, 224, 224) input."""
        out = self.maxpool1(F.relu(self.conv1_2(F.relu(self.conv1_1(x)))))       # 222 -> 222 -> 112
        out = self.maxpool2(F.relu(self.conv2_2(F.relu(self.conv2_1(out)))))     # 110 -> 110 -> 56
        out = F.relu(self.conv3_2(F.relu(self.conv3_1(out))))                    # 54 -> 54
        out = self.maxpool3(F.relu(self.conv3_3(out)))                           # 54 -> 28
        out = F.relu(self.conv4_2(F.relu(self.conv4_1(out))))                    # 26 -> 26
        out = self.maxpool4(F.relu(self.conv4_3(out)))                           # 26 -> 14
        out = F.relu(self.conv5_2(F.relu(self.conv5_1(out))))                    # 12 -> 12
        out = self.maxpool5(F.relu(self.conv5_3(out)))                           # 12 -> 7

        out = self.feature(out)            # 1 * 512 * 1 * 1
        out = out.view(out.size(0), -1)    # flatten to 1 * 512
        return out

    # NOTE: the original file overrode __call__ to call forward() directly.
    # That is redundant (nn.Module.__call__ already dispatches to forward)
    # and bypasses nn.Module's hook machinery, so it was removed; vgg(x)
    # behaves the same.

    def get_name(self):
        """Return the short model identifier used for file naming/logging."""
        return "vgg16"


def preprocessing(x):
    """Convert a BGR image (H x W x 3, as returned by cv2.imread) into a
    normalized 1 x 3 x 224 x 224 float tensor.

    Steps: resize to 224 x 224, reorder HWC -> CHW, scale to [0, 1],
    subtract per-channel means, and add a batch dimension.

    NOTE(review): the means (103.939, 116.779, 123.68) are the Caffe-style
    ImageNet channel means in B, G, R order, matching OpenCV's channel
    layout — confirm the pretrained weights expect this normalization.
    """
    image = cv2.resize(x, (224, 224))

    # Channel means rescaled to the [0, 1] range used below.
    means = np.array([103.939, 116.779, 123.68]) / 255.

    image = np.transpose(image, (2, 0, 1)) / 255.  # HWC -> CHW, scale to [0, 1]
    image[0] -= means[0]  # subtract B mean
    image[1] -= means[1]  # subtract G mean
    image[2] -= means[2]  # subtract R mean
    image = np.expand_dims(image, axis=0)  # add batch dim -> 1 x 3 x 224 x 224

    # torch.autograd.Variable is deprecated since PyTorch 0.4; plain
    # tensors carry autograd state, so return the tensor directly.
    return torch.from_numpy(image).float()


def main():
    """Smoke test: load the pruned-VGG16 weights, run one image through the
    network, and print the L1-normalized 512-d feature vector."""
    vgg = VGG16()
    # Bug fix: "./vgg_test.pth" holds the *whole pickled model* (see the
    # export script: torch.save(vgg, "vgg_test.pth")), which
    # load_state_dict() rejects because it is not a mapping. Load the
    # state-dict file saved alongside it instead.
    vgg.load_state_dict(torch.load("./vgg_dict_test.pth"), strict=False)
    vgg.eval()

    image = cv2.imread("test_image/test_1.png")
    image = preprocessing(image)
    with torch.no_grad():
        # Call the module (not .forward()) so nn.Module hooks still run.
        feature = vgg(image)
    feature = np.sum(feature.cpu().numpy(), axis=0)
    feature /= np.sum(feature)  # normalize so the components sum to 1
    print(feature)


if __name__ == "__main__":
    # Load the stock torchvision VGG16 checkpoint from the default torch
    # hub cache (machine-specific Windows path — adjust per machine).
    pretrained = torch.load("C:\\Users\\weiz\\.cache\\torch\\hub\\checkpoints\\vgg16-397923af.pth")
    # print(pretrained)

    vgg = VGG16()
    vgg_dict = vgg.state_dict()

    # Key names differ between the checkpoint ("features.0.weight", ...)
    # and this model ("conv1_1.weight", ...), so weights are copied
    # *positionally*: both state dicts enumerate the 13 conv layers'
    # weight/bias tensors in the same order, and zip() stops once the
    # shorter (pruned) dict is exhausted, silently dropping the
    # checkpoint's classifier (fc) parameters. This relies entirely on
    # dict insertion order matching between the two models.
    pretrained_dict = {}
    for (k1, v1), (k2, v2) in zip(pretrained.items(), vgg_dict.items()):
        pretrained_dict[k2] = v1
    # Name-based matching would find nothing here because no keys overlap:
    # pretrained_dict = {k: v for k, v in pretrained.items() if k in vgg_dict}
    # print(pretrained_dict)

    vgg_dict.update(pretrained_dict)
    vgg.load_state_dict(vgg_dict)

    torch.save(vgg.state_dict(), "vgg_dict_test.pth")  # save parameters only (state_dict)
    torch.save(vgg, "vgg_test.pth")                    # save full model (structure + parameters)

    # main()

# The pth-to-ONNX conversion code follows in the original post.
#
# Adapted from: blog.csdn.net/qq_31112205/article/details/126178192