24 Model Fine-tuning (Finetune)

1. Transfer Learning & Model Finetune

1.1 Transfer Learning

Transfer Learning: a branch of machine learning that studies how knowledge from a source domain can be applied to a target domain.
(Figure: traditional machine learning vs. transfer learning)

Traditional machine learning:
A separate learning system, i.e. a model, is trained for each task; as the figure above shows, three different tasks yield three different models.

Transfer learning:
First learn on the source task to obtain knowledge, then use the knowledge learned on the source task when training the model for the target task; in other words, the model makes use not only of the target tasks but also of the source tasks.

1.2 Model Finetune

1.2.1 The Concept of Model Finetune

Model Finetune: transfer learning applied to a model
Model fine-tuning:
Model fine-tuning is itself a transfer-learning process: the weights a model learns during training are the "knowledge" in transfer learning, and this knowledge can be transferred; carrying it over to a new task is what completes the transfer.

Why fine-tune:
In the new task the amount of data is too small to train a large model from scratch, so Model Finetune is used to help train a good model and to make training faster.

Transferring a convolutional neural network:
(Figure: a convolutional neural network split into a features extractor and a classifier)
The convolutional neural network is divided into two parts: features extractor + classifier

  • features extractor: the part that is common across tasks; it is usually kept as-is
  • classifier: the output layer, which is fine-tuned according to the requirements of the new task (see the sketch after this list)
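
A minimal sketch (not part of the tutorial code below) of what this split looks like for torchvision's resnet18: everything up to the global average pooling acts as the features extractor, and the final fc layer is the classifier.

import torch
import torch.nn as nn
import torchvision.models as models

resnet18 = models.resnet18()

# everything except the final fc layer serves as the features extractor
features_extractor = nn.Sequential(*list(resnet18.children())[:-1])
# the final fully-connected layer serves as the classifier
classifier = resnet18.fc

x = torch.randn(1, 3, 224, 224)
feats = features_extractor(x)                    # shape: (1, 512, 1, 1)
logits = classifier(torch.flatten(feats, 1))     # shape: (1, 1000)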

1.2.2 Model Finetune Steps

Model Finetune:
First prepare the model: load the parameters of a model pretrained on the source task and modify the model to fit the new task; then run the actual training, taking care to preserve (or only gently update) the pretrained parameters. The concrete steps and training methods are as follows.

Fine-tuning steps (a minimal sketch follows this list):

  1. Obtain the pretrained model parameters
  2. Load them into the model (load_state_dict)
  3. Replace the output layer
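
A minimal sketch of these three steps, assuming the resnet18 weights file used later in this post (resnet18-5c106cde.pth) has already been downloaded to the working directory and the new task has 2 classes:

import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 2

# 1. obtain the pretrained parameters (downloaded beforehand)
state_dict = torch.load("resnet18-5c106cde.pth")

# 2. load them into a freshly constructed model
model = models.resnet18()
model.load_state_dict(state_dict)

# 3. replace the output layer to match the new task
model.fc = nn.Linear(model.fc.in_features, num_classes)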

Fine-tuning training methods:

  • Freeze the pretrained parameters; two ways to do this:
    • requires_grad = False
    • lr = 0
  • Use a smaller learning rate for the Features Extractor part (params_group)

Note:
The optimizer can manage different parameter groups, so different hyperparameters can be set for each group; this is how a smaller learning rate is assigned to the Features Extractor part, as sketched below.
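
A minimal sketch of the two training methods, assuming a resnet18 whose fc layer has already been replaced for a 2-class task (in practice you would pick one of the two):

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)

# method 1: freeze the pretrained (features extractor) parameters
for name, param in model.named_parameters():
    if not name.startswith("fc."):
        param.requires_grad = False

# method 2: give the features extractor a smaller learning rate via parameter groups
# (setting its lr to 0 would freeze it as well)
fc_params = list(model.fc.parameters())
fc_param_ids = set(map(id, fc_params))
base_params = [p for p in model.parameters() if id(p) not in fc_param_ids]

optimizer = optim.SGD([
    {"params": base_params, "lr": 0.0001},   # small lr for the features extractor
    {"params": fc_params, "lr": 0.001},      # normal lr for the classifier
], momentum=0.9)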

2. Finetune in PyTorch

2.1 A Model Finetune Example

Data: https://download.pytorch.org/tutorial/hymenoptera_data.zip
Model: https://download.pytorch.org/models/resnet18-5c106cde.pth

2.1.1 Directory Structure

The dataset is unpacked to data/hymenoptera_data (with train/ and val/ sub-folders) and the pretrained weights file resnet18-5c106cde.pth is placed under data/, two directory levels above the training script; my_dataset.py and common_tools.py live in a tools package (see the imports and relative paths used in the code below).

2.1.2 Code Walkthrough

my_dataset.py

# -*- coding: utf-8 -*-
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}   # leftover from an earlier RMB-classification example; unused here


class AntsDataset(Dataset):
    """Binary ants/bees classification dataset; reads images from per-class sub-folders."""
    def __init__(self, data_dir, transform=None):
        self.label_name = {"ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img,label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over the class sub-directories
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over the images in this class
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, int(label)))

        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(data_dir))
        return data_info

common_tools.py

# -*- coding: utf-8 -*-

import torch
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


def transform_invert(img_, transform_train):
    """
    Undo the transforms applied to the data (de-normalize and convert back to a PIL image)
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = img_.detach().numpy() * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
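
A minimal usage sketch for the transform_invert defined above (hypothetical; the training script below never calls it): apply the training transforms to an image, then invert them to visualize what the network actually receives. The image path is a placeholder.

from PIL import Image
import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img_pil = Image.open("ants/some_ant.jpg").convert("RGB")      # hypothetical path
img_tensor = train_transform(img_pil)                         # transformed tensor, C*H*W
img_restored = transform_invert(img_tensor, train_transform)  # back to a PIL image
img_restored.show()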

finetune_resnet18.py

# -*- coding: utf-8 -*-

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device for training
print("use device :{}".format(device))

set_seed(1)  # set the random seed
label_name = {"ants": 0, "bees": 1}

# hyper-parameter settings
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7


# ============================ step 1/5 data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build the Dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

# 1/3 build the model
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the convolutional layers
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))


# 3/3 replace the fc layer
num_ftrs = resnet18_ft.fc.in_features            # number of input features of the original resnet18 fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)


resnet18_ft.to(device)        # move the model to the selected device
# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
# Method 2: small learning rate for the conv layers
flag = 0
# flag = 1
if flag:
    # split the model parameters into two groups: resnet18_ft.fc.parameters() and base_params
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # ids (memory addresses) of the fc parameters
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())

    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR*0.1},   # smaller lr for the features extractor (0 would freeze it)
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)               # choose the optimizer

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)     # learning-rate decay schedule


# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)   # the training data must also be moved to the selected device
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # if flag_m1:
            # print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val/len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # the valid curve stores one loss per epoch, so convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Results without finetune:
Without finetune, the final training results are a loss of 0.5729 and an accuracy of 67.50%.

Results with transfer learning (pretrained parameters loaded):
After loading the pretrained parameters, the accuracy starts out above sixty percent and quickly reaches a high level, so finetune lets the model train much faster.

Method 1: freezing the convolutional layers, results:
With the convolutional layers frozen, their parameters remain unchanged during the iterations.

Method 2: small learning rate for the conv layers, results:
Training results when the convolutional layers use a smaller learning rate, set here to 0.0001.

Method 2: small learning rate for the conv layers (learning rate set to 0), results:
With the learning rate set to 0, the convolutional-layer parameters do not change during training, so this method has the same effect as Method 1.
