04_PyTorch model training [Finetune weight initialization]

In practice, we usually take the weight parameters of an already-trained model as the initialization parameters of our own model.
This is known as finetuning, or more broadly as transfer learning. The finetune technique in transfer learning essentially
gives our newly constructed model a better set of initial weights.
Finetune weight initialization amounts to initializing the model, and the process takes three steps (a minimal end-to-end sketch follows the list):
Step 1: save the model, i.e., have a pre-trained model available;
Step 2: load the model and extract the weights from the pre-trained model;
Step 3: initialize, i.e., "put" the corresponding weights into the new model.
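A minimal sketch of the three steps, assuming pretrained_net and new_net are existing nn.Module instances and the file name net_params.pkl is illustrative:

import torch

# Step 1: save the pre-trained model's weights
torch.save(pretrained_net.state_dict(), "net_params.pkl")

# Step 2: load the saved weights back into a dictionary
pretrained_dict = torch.load("net_params.pkl")

# Step 3: "put" the matching weights into the new model
new_state_dict = new_net.state_dict()
new_state_dict.update({k: v for k, v in pretrained_dict.items() if k in new_state_dict})
new_net.load_state_dict(new_state_dict)

Sections 1) to 3) below walk through these steps in detail on a small convolutional network.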
1. Code
1) Load data
import torchvision.transforms as transforms
# data preprocessing settings
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)

trainTransform = transforms.Compose([
    # resize so the shorter side of the image is 32
    transforms.Resize(32),
    # pad 4 pixels on each side, then randomly crop a 32x32 patch
    transforms.RandomCrop(32, padding=4),
    # first transpose the data from H*W*C to C*H*W,
    # then divide every pixel by 255 so values fall in [0, 1]
    transforms.ToTensor(),
    # standardize the image with the channel means/stds above
    normTransform
])

validTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])
import os
import sys
sys.path.append("../util/")  # relative path to the utility module
from torch.utils.data import DataLoader
from utils import MyDataset

base_dir = "E:/pytorch_learning"  # change to the absolute path that contains the Data directory
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")

# build MyDataset instances
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)

# batch sizes
train_bs = 16
valid_bs = 16

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)
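As a sanity check, a small sketch that pulls one batch and prints its shape (assuming MyDataset yields (image, label) pairs, the usual convention for such custom datasets):

# fetch a single batch and verify the tensor shapes
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([16, 3, 32, 32])
print(labels.shape)  # expected: torch.Size([16])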

2) Define the network

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # in 3, out 6, kernel 5: 3x32x32 -> 6x28x28
        self.pool1 = nn.MaxPool2d(2, 2)   # kernel 2, stride 2: 6x28x28 -> 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # in 6, out 16, kernel 5: 6x14x14 -> 16x10x10
        self.pool2 = nn.MaxPool2d(2, 2) # 16x10x10 -> 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16*5*5 -> 120
        self.fc2 = nn.Linear(120, 84)          # 120 -> 84
        self.fc3 = nn.Linear(84, 10)           # 84 -> 10

    #conv1->relu->pool1->conv2->relu->pool2->fc1->relu->fc2->relu->fc3
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    # define the weight initialization
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)  # Xavier normal distribution
                if m.bias is not None:
                    m.bias.data.zero_()  # zero the bias
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)  # set all weights to 1
                m.bias.data.zero_()     # zero the bias
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)  # normal init, mean 0, std 0.01
                m.bias.data.zero_()  # zero the bias

net = Net()  # create a network
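Before loading pretrained weights, the custom initializer can serve as a fallback for any layers the pretrained file does not cover. A quick check that it behaves as described (the printed values are expectations, not guaranteed output):

net.initialize_weights()
print(net.conv1.bias.data.abs().sum())  # biases zeroed: tensor(0.)
print(net.fc1.weight.data.std())        # fc weights ~ N(0, 0.01): roughly 0.01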

3) Finetune weight initialization

# load the .pkl file of pre-trained weights
pretrained_dict = torch.load(r'E:\pytorch_learning\Data\net_params.pkl')
# get the current network's state_dict
net_state_dict = net.state_dict()
# drop the weight parameters whose keys do not match the new model
pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
# update the new model's parameter dictionary
net_state_dict.update(pretrained_dict_1)
# "put" the dictionary containing the pre-trained parameters into the new model
net.load_state_dict(net_state_dict)

2. Effect

Printing the dictionary shows that each key is the name of a network layer's parameter and each value is the corresponding weight tensor, for example:
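A short sketch of that inspection (the shapes shown match the Net defined above):

# keys are layer parameter names, values are weight tensors
for k, v in net.state_dict().items():
    print(k, tuple(v.shape))
# conv1.weight (6, 3, 5, 5)
# conv1.bias   (6,)
# ...
# fc3.bias     (10,)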

The remaining parameter settings are covered in the following chapters.
