In practice, we usually initialize our model with the weights of a model that has already been trained. This is known as finetuning, which is one form of the broader technique called transfer learning.
The essence of finetuning in transfer learning is to give our newly constructed model a better starting point for its weights.
The three-step process of finetune weight initialization
Finetuning is essentially a way of initializing the model, and it consists of three steps:
Step 1: Save the model, so that a pre-trained model is available;
Step 2: Load the model and extract the weights of the pre-trained model;
Step 3: Initialize the new model by "putting" the corresponding weights into it.
1. Code
1) Load data
import torchvision.transforms as transforms
# Data preprocessing settings: per-channel mean and std used to standardize images.
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)

# Training pipeline: the network below expects 3x32x32 inputs.
trainTransform = transforms.Compose([
    # Resize so the SHORTER side is 32 px (an int size keeps the aspect ratio;
    # the RandomCrop below is what guarantees a 32x32 output).
    transforms.Resize(32),
    # Pad 4 pixels on every side, then take a random 32x32 crop (augmentation).
    transforms.RandomCrop(32, padding=4),
    # Convert the HxWxC image to a CxHxW tensor and divide pixel values by 255,
    # so they lie in [0, 1].
    transforms.ToTensor(),
    # Standardize each channel with normMean / normStd.
    normTransform,
])

# Validation pipeline: deterministic counterpart of the training geometry.
# Previously it did no resizing at all, so any non-32x32 image would reach the
# network with the wrong spatial size. For 32x32 inputs these two steps are no-ops.
validTransform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    normTransform,
])
import os
import sys
# Make the project helper package importable. The path is relative to the
# current working directory, so the script must be launched from its own folder.
sys.path.append("./../util/")
from torch.utils.data import DataLoader
from utils import MyDataset

# Absolute path of the directory that contains the "Data" folder; edit per machine.
base_dir = "E:/pytorch_learning"

# Text files listing the training / validation samples.
train_txt_path, valid_txt_path = (
    os.path.join(base_dir, "Data", fname) for fname in ("train.txt", "valid.txt")
)

# Datasets: training uses the augmenting transform, validation the plain one.
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)

# Batch size shared by both loaders.
train_bs = valid_bs = 16

# Only the training loader reshuffles every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)
2) Define the network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class scores.

    Layer order: conv1 -> relu -> pool1 -> conv2 -> relu -> pool2
    -> fc1 -> relu -> fc2 -> relu -> fc3.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # in 3, out 6, kernel 5: 3x32x32 -> 6x28x28
        self.pool1 = nn.MaxPool2d(2, 2)  # kernel 2, stride 2: 6x28x28 -> 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # in 6, out 16, kernel 5: 6x14x14 -> 16x10x10
        self.pool2 = nn.MaxPool2d(2, 2)  # 16x10x10 -> 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 * 5 * 5 -> 120
        self.fc2 = nn.Linear(120, 84)  # 120 -> 84
        self.fc3 = nn.Linear(84, 10)  # 84 -> 10

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class scores of shape (batch, 10) for a (batch, 3, 32, 32) input."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 400),
        # an input of the wrong spatial size now fails loudly in fc1 instead of
        # being silently reshaped into a different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def initialize_weights(self):
        """Re-initialize the parameters of every submodule in place.

        Conv2d: Xavier-normal weights, zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights ~ N(0, 0.01), zero bias.
        Modules created without a bias (or a non-affine BatchNorm) are skipped
        for the missing tensors instead of crashing.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has weight/bias set to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


net = Net()  # Create the network whose weights will be finetuned below.
3) Finetune weight initialization
# Load the saved parameter dictionary (name -> tensor) from the .pkl file.
# The path is built from base_dir like the other data paths. The previous raw
# string r'E:\\...' contained literal doubled backslashes.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict below copies the values into net's own tensors anyway.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
pretrained_dict = torch.load(
    os.path.join(base_dir, "Data", "net_params.pkl"), map_location="cpu"
)
# Parameter dictionary of the current (new) network.
net_state_dict = net.state_dict()
# Drop pretrained entries the new network cannot accept: names it does not have,
# and names whose tensor shape differs (e.g. a classifier resized for a new
# number of classes), which would otherwise make load_state_dict raise.
pretrained_dict_1 = {
    k: v
    for k, v in pretrained_dict.items()
    if k in net_state_dict and v.shape == net_state_dict[k].shape
}
# Overwrite the matching entries of the new model's dictionary.
net_state_dict.update(pretrained_dict_1)
# "Put" the dictionary, now holding the pretrained weights, into the new model.
net.load_state_dict(net_state_dict)
2. Effect
Printing the dictionary shows that each key is the name of a network layer's parameter (e.g. `conv1.weight`), and each value is the corresponding weight tensor.
The remaining parameter settings are covered in the following chapters.