04_Entrenamiento del modelo PyTorch [inicialización del peso de Finetune]

En aplicaciones prácticas, solemos utilizar los parámetros de peso de un modelo que ya ha sido entrenado como parámetros de inicialización de nuestro modelo,
También conocido como Finetune , más ampliamente conocido como transferencia de aprendizaje. La tecnología Finetune en el aprendizaje por transferencia es esencialmente
Deje que nuestro modelo recién construido tenga un mejor valor inicial de pesos.
El proceso de tres pasos de inicialización del peso de ajuste fino , el ajuste fino es equivalente a inicializar el modelo, y el proceso comparte tres pasos:
El primer paso: guardar el modelo y tener un modelo pre-entrenado;
Paso 2: Cargue el modelo y saque los pesos en el modelo pre-entrenado;
Paso 3: Inicializar, " poner " los pesos correspondientes en el nuevo modelo
1. Código
1) Cargar datos
import torchvision.transforms as transforms
# 数据预处理设置
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)

trainTransform = transforms.Compose([
    #重置图像分辨率为32x32
    transforms.Resize(32),
    #上下左右均填充4个像素,然后随机裁剪32*32
    transforms.RandomCrop(32, padding=4),
    #先对数据进行转置,将h*w*c变成c*h*w
    #然后将所有像素除以255,使得像素归一化为[0-1]
    transforms.ToTensor(),
    #对图像进行标准化
    normTransform
])

validTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])
import os
import sys
sys.path.append( "./../util/")  #工具类的相对位置
from torch.utils.data import DataLoader
from utils import MyDataset 



base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")

# 构建MyDataset实例
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)

#批次大小
train_bs = 16
valid_bs = 16

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)

2) Definir la red

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  #输入3 输出6 kernel5  3x32x32 -> 6x28x28
        self.pool1 = nn.MaxPool2d(2, 2) # kernel 2 stride 2   6x28x28 -> 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5) #输出6 输入16 kernel 5  6x14x14 -> 16x10x10
        self.pool2 = nn.MaxPool2d(2, 2) # 16x10x10 -> 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 * 5 * 5 -> 120
        self.fc2 = nn.Linear(120, 84) #120 -> 84
        self.fc3 = nn.Linear(84, 10)  #84 -> 10

    #conv1->relu->pool1->conv2->relu->pool2->fc1->relu->fc2->relu->fc3
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    # 定义权值初始化
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data) #正态分布
                if m.bias is not None:
                    m.bias.data.zero_()  #偏置全部归0
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)  # 权重全部归1
                m.bias.data.zero_()      # 偏置全部归0
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01) #正态分布初始化
                m.bias.data.zero_() #偏置全部归0
net = Net()     # 创建一个网络

3) inicialización de peso Finetune

#加载pkl
pretrained_dict = torch.load(r'E:\\pytorch_learning\\Data\\net_params.pkl')
# 获取当前网络的dict
net_state_dict = net.state_dict()
# 剔除不匹配的权值参数
pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
# 更新新模型参数字典
net_state_dict.update(pretrained_dict_1)
# 将包含预训练模型参数的字典"放"到新模型中
net.load_state_dict(net_state_dict)

2. Efecto

Al imprimir el diccionario, encontramos que la clave del diccionario es la capa de red y el valor es el parámetro de peso correspondiente

 Los siguientes parámetros están en los siguientes capítulos.

Supongo que te gusta

Origin blog.csdn.net/zhang2362167998/article/details/128821698
Recomendado
Clasificación