[Pytorch Neural Network Practical Case] 25 (with data enhancement) Recognizing multiple birds based on transfer learning (CUB-200 dataset)

1 Data augmentation

Among the EficientNet series models with the best classification results, the EfficientNet-B7 version of the model is trained using random data augmentation.

The RandAugment method is also the current mainstream data enhancement method. Using the RandAugment method for training will improve the accuracy of the model.

2 RandAugment

2.1 Introduction to the RandAugment method

The RandAugment method is a new data augmentation method, which is simpler and better to use than the automatic data augmentation (AutoOAugment) method. It can directly replace the AutoAugment method in the original training framework.

2.1.1 Tip

The AuoAugment method contains more than 30 parameters and can perform various transformations on the image data (see the paper numbered 1805.09501 on the arXiv website).

2.2 The composition of the RandAugment method

The RandAugment method is based on the AutoAugment method, and more than 30 parameters are optimized and managed at the policy level, so that these more than 30 parameters are simplified into two parameters: N times of transformation of the picture and the intensity M of each transformation. Among them, the intensity M of each transformation is 0 to 10 (only integers), indicating the size of the enhancement and distortion of the original picture.

The RandAugment approach is result-oriented, making the data augmentation process more user-oriented. While reducing the computational consumption of AutoAugment, the enhanced effect becomes controllable. For details, please refer to related papers (see the paper numbered 1909.13719 on the arXIV website).

2.2 Code acquisition

https://github.com/heartInsert/randaugment
# 只有一个代码文件Rand_Augment,py,将其下载后,直接引入代码即可使用。

3 Cases in this section (recognition with data augmentation)

3.1 Case Introduction

Using transfer learning to fine-tune the pre-trained model to achieve data augmentation, let it learn the bird dataset and recognize a variety of birds.

3.2 Code implementation: load_data function loads image name and label loading----Transfer_bird2_Augmentation.py (Part 1)

import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt #plt 用于显示图片
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset,DataLoader
import torchvision
import torchvision.models as model
from torchvision.transforms import ToPILImage
import torchvision.transforms as transforms
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# 1.1 实现load_data函数加载图片名称与标签的加载,并使用torch.utils.data接口将其封装成程序可用的数据集类OwnDataset。
def load_dir(directory,labstart=0): # 获取所有directory中的所有图与标签
    # 返回path指定的文件夹所包含的文件或文件名的名称列表
    strlabels = os.listdir(directory)
    # 对标签进行排序,以便训练和验证按照相同的顺序进行:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
    strlabels.sort()
    # 创建文件标签列表
    file_labels = []
    for i,label in enumerate(strlabels):
        print(label)
        jpg_names = glob.glob(os.path.join(directory,label,"*.jpg"))
        print(jpg_names)
        # 加入列表
        file_labels.extend(zip(jpg_names, [i + labstart] * len(jpg_names)))
    return file_labels,strlabels

def load_data(dataset_path): # 定义函数load_data函数完成对数据集中图片文件名称和标签的加载。
    # 该函数可以实现两层文件夹的嵌套结构。其中,外层结构使用load_data函数进行遍历,内层结构使用load_dir函进行遍历。
    sub_dir = sorted(os.listdir(dataset_path)) # 跳过子文件夹:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
    start = 1 # 第0类是none
    tfile_lables,tstrlabels = [],['none'] # 在制作标签时,人为地在前面添加了一个序号为0的none类。这是一个训练图文类模型的技巧,为了区分模型输出值是0和预测值是0这两种情况。
    for i in sub_dir:
        directory = os.path.join(dataset_path,i)
        if os.path.isdir(directory) == False: # 只处理文件夹中的数据
            print(directory)
            continue
        file_labels,strlables = load_dir(directory,labstart=start)
        tfile_lables.extend(file_labels)
        tstrlabels.extend(strlables)
        start = len(strlables)
    # 将数据路径与标签解压缩,把数据路径和标签解压缩出来
    filenames,labels = zip(*tfile_lables)
    return filenames, labels, tstrlabels

3.3 Code Implementation: Custom Dataset Class OwnDataset----Transfer_bird2_Augmentation.py (Part 2)

# 1.2 实现自定义数据集OwnDataset
def default_loader(path) : # 定义函数加载图片
    return Image.open(path).convert('RGB')

class OwnDataset(Dataset): # 复用性较强,可根据自己的数据集略加修改使用
    # 在PyTorch中,提供了一个torch.utis.data接口,可以用来对数据集进行封装。在实现时,只需要继承torch.utis.data.Dataset类,并重载其__gettem__方法。
    # 在使用时,框架会向__gettem__方法传入索引index,在__gettem__方法内部根据指定index加载数据,并返回。
    def __init__(self,img_dir,labels,indexlist=None,transform=transforms.ToTensor(),loader=default_loader,cache=True): # 初始化
        self.labels = labels # 存放标签
        self.img_dir = img_dir # 样本图片文件名
        self.transform = transform # 预处理方法
        self.loader = loader # 加载方法
        self.cache = cache # 缓存标志
        if indexlist is None: # 要加载的数据序列
            self.indexlist = list(range(len(self.img_dir)))
        else:
            self.indexlist = indexlist
        self.data = [None] * len(self.indexlist) # 存放样本图片

    def __getitem__(self, idx): # 加载指定索引数据
        if self.data[idx] is None: # 第一次加载
            data = self.loader(self.img_dir[self.indexlist[idx]])
            if self.transform:
                data = self.transform(data)
        else:
            data = self.data[idx]
        if self.cache: # 保存到缓存里
            self.data[idx] = data
        return data,self.labels[self.indexlist[idx]]

    def __len__(self): # 计算数据集长度
        return len(self.indexlist)

3.4 Code combat: test dataset----Transfer_bird2_Augmentation.py (Part 3) [Data Augmentation Module]

# 1.3 测试数据集:在完成数据集的制作之后,编写代码对其进行测试。
# 数据增强模块
from Rand_Augment import  Rand_Augment
data_transform = { #定义数据的预处理方法
    'train':transforms.Compose([
        Rand_Augment(), # 数据增强的方法带入 仅此一处修改
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'val':transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}
def Reduction_img(tensor,mean,std): #还原图片,实现了图片归一化的逆操作,显示数据集中的原始图片。
    dtype = tensor.dtype
    mean = torch.as_tensor(mean,dtype=dtype,device=tensor.device)
    std = torch.as_tensor(std,dtype=dtype,device=tensor.device)
    tensor.mul_(std[:,None,None]).add_(mean[:,None,None]) # 还原操作

dataset_path = r'./data/cub200/' # 加载数据集路径
filenames,labels,classes = load_data(dataset_path) # 调用load_data函数对数据集中图片文件名称和标签进行加载,其返回对象classes中包含全部的类名。
# 打乱数据顺序
# 110-115行对数据文件列表的序号进行乱序划分,分为测试数据集和训练数集两个索引列表。该索引列表会传入OwnDataset类做成指定的数据集。
np.random.seed(0)
label_shuffle_index = np.random.permutation(len(labels))
label_train_num = (len(labels)//10) * 8 # 划分训练数据集和测试数据集
train_list = label_shuffle_index[0:label_train_num]
test_list = label_shuffle_index[label_train_num:] # 没带:

train_dataset = OwnDataset(filenames,labels,train_list,data_transform['train'])# 实例化训练数据集
val_dataset = OwnDataset(filenames,labels,test_list,data_transform['val']) # 实例化测试数据集
# 实例化批次数据集:OwnDataset类所定义的数据集,其使用方法与PyTorch中的内置数据集的使用方法完全一致,配合DataLoader接口即可生成可以进行训练或测试的批次数据。具体代码如下。
train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=True)

sample = iter(train_loader) # 获取一批次数据,进行测试
images,labels = sample.next()
print("样本形状",np.shape(images))
print("标签个数",len(classes))
mulimgs = torchvision.utils.make_grid(images[:10],nrow=10) # 拼接多张图片
Reduction_img(mulimgs,[0.485,0.456,0.406],[0.229,0.224,0.225])
_img = ToPILImage()(mulimgs) # 将张量转化为图片
plt.axis('off')
plt.imshow(_img) # 显示
plt.show()
print(','.join('%5s' % classes[labels[j]] for j in range(len(images[:10]))))

output:

Sample shape torch.Size([32, 3, 224, 224]) Number of
labels 6

10 images in the output dataset

3.5 Code combat: obtaining and transforming the ResNet model----Transfer_bird2_Augmentation.py (Part 4)

# 1.4 获取并改造ResNet模型:获取ResNet模型,并加载预训练模型的权重。将其最后一层(输出层)去掉,换成一个全新的全连接层,该全连接层的输出节点数与本例分类数相同。
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# get_ResNet函数,获取预训练模型,可指定pretrained=True来实现自动下载预训练模型,也可指定loadfile来从本地路径加载预训练模型。
def get_ResNet(classes,pretrained=True,loadfile=None):
    ResNet = model.resnet101(pretrained) # 自动下载官方的预训练模型
    if loadfile != None:
        ResNet.load_state_dict(torch.load(loadfile)) # 加载本地模型
    # 将所有的参数层进行冻结:设置模型仅最后一层可以进行训练,使模型只针对最后一层进行微调。
    for param in ResNet.parameters():
        param.requires_grad = False
    # 输出全连接层的信息
    print(ResNet.fc)
    x = ResNet.fc.in_features # 获取全连接层的输入
    ResNet.fc = nn.Linear(x,len(classes)) # 定义一个新的全连接层
    print(ResNet.fc) # 最后输出新的模型
    return ResNet
ResNet = get_ResNet(classes) # 实例化模型
ResNet.to(device=device)

3.6 Code combat: define loss function, training function and test function, fine-tune the last layer of the model----Transfer_bird2_Augmentation.py (Part 5)

# 1.5 定义损失函数、训练函数及测试函数,对模型的最后一层进行微调。
criterion = nn.CrossEntropyLoss()
# 指定新加的全连接层的学习率
optimizer = torch.optim.Adam([{'params':ResNet.fc.parameters()}],lr=0.01)
def train(model,device,train_loader,epoch,optimizer): # 定义训练函数
    model.train()
    allloss = []
    for batch_idx,data in enumerate(train_loader):
        x,y = data
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat,y)
        loss.backward()
        allloss.append(loss.item())
        optimizer.step()
    print('Train Epoch:{}\t Loss:{:.6f}'.format(epoch,np.mean(allloss))) # 输出训练结果

def test(model,device,val_loader): # 定义测试函数
    model.eval()
    test_loss = []
    correct = []
    with torch.no_grad(): # 使模型在运行时不进行梯度跟踪,可以减少模型运行时对内存的占用。
        for i,data in enumerate(val_loader):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss.append(criterion(y_hat,y).item()) # 收集损失函数
            pred = y_hat.max(1,keepdim=True)[1] # 获取预测结果
            correct.append(pred.eq(y.view_as(pred)).sum().item()/pred.shape[0]) # 收集精确度
    print('\nTest:Average loss:{:,.4f},Accuracy:({:,.0f}%)\n'.format(np.mean(test_loss),np.mean(correct)*100)) # 输出测试结果

# 迁移学习的两个步骤如下
if __name__ == '__main__':
# 迁移学习步骤①:固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛。
    firstmodepth = './data/cub200/firstmodepth_1.pth' # 定义模型文件的地址
    if os.path.exists(firstmodepth) == False:
        print("—————————固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛—————————")
        for epoch in range(1,2): # 迭代两次
            train(ResNet,device,train_loader,epoch,optimizer)
            test(ResNet,device,val_loader)
        # 保存模型
        torch.save(ResNet.state_dict(),firstmodepth)

3.7 Code combat: Global fine-tuning of the model using degenerate learning rate----Transfer_bird2_Augmentation.py (Part 6)

# 1.6 使用退化学习率对模型进行全局微调
#迁移学习步骤②:使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节,即将模型的每层权重都设为可训练,并定义带有退化学习率的优化器。(1.6部分)
    secondmodepth = './data/cub200/firstmodepth_2.pth'
    optimizer2 = optim.SGD(ResNet.parameters(),lr=0.001,momentum=0.9) # 第198行代码定义带有退化学习率的SGD优化器。该优化器常用来对模型进行手动微调。有实验表明,使用经过手动调节的SGD优化器,在训练模型的后期效果优于Adam优化器。
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer2,step_size=2,gamma=0.9) # 由于退化学习率会在训练过程中不断地变小,为了防止学习率过小,最终无法进行权重需要对其设置最小值。当学习率低于该值时,停止对退化学习率的操作。
    for param in ResNet.parameters(): # 所有参数设计为可训练
        param.requires_grad = True
    if os.path.exists(secondmodepth):
        ResNet.load_state_dict(torch.load(secondmodepth)) # 加载本地模型
    else:
        ResNet.load_state_dict(torch.load(firstmodepth)) # 加载本地模型
    print("____使用较小的学习率,对全部模型进行训练,定义带有退化学习率的优化器______")
    for epoch in range(1,100):
        train(ResNet,device,train_loader,epoch,optimizer2)
        if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:
            exp_lr_scheduler.step()
            print("___lr:",optimizer2.state_dict()['param_groups'][0]['lr'])
        test(ResNet,device,val_loader)
    # 保存模型
    torch.save(ResNet.state_dict(),secondmodepth)

4 Code overview Transfer_bird2_Augmentation.py

import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt #plt 用于显示图片
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset,DataLoader
import torchvision
import torchvision.models as model
from torchvision.transforms import ToPILImage
import torchvision.transforms as transforms
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# 1.1 实现load_data函数加载图片名称与标签的加载,并使用torch.utils.data接口将其封装成程序可用的数据集类OwnDataset。
def load_dir(directory,labstart=0): # 获取所有directory中的所有图与标签
    # 返回path指定的文件夹所包含的文件或文件名的名称列表
    strlabels = os.listdir(directory)
    # 对标签进行排序,以便训练和验证按照相同的顺序进行:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
    strlabels.sort()
    # 创建文件标签列表
    file_labels = []
    for i,label in enumerate(strlabels):
        print(label)
        jpg_names = glob.glob(os.path.join(directory,label,"*.jpg"))
        print(jpg_names)
        # 加入列表
        file_labels.extend(zip(jpg_names, [i + labstart] * len(jpg_names)))
    return file_labels,strlabels

def load_data(dataset_path): # 定义函数load_data函数完成对数据集中图片文件名称和标签的加载。
    # 该函数可以实现两层文件夹的嵌套结构。其中,外层结构使用load_data函数进行遍历,内层结构使用load_dir函进行遍历。
    sub_dir = sorted(os.listdir(dataset_path)) # 跳过子文件夹:在不同的操作系统中,加载文件夹的顺序可能不同。目录不同的情况会导致在不同的操作系统中,模型的标签出现串位的现象。所以需要对文件夹进行排序,保证其顺序的一致性。
    start = 1 # 第0类是none
    tfile_lables,tstrlabels = [],['none'] # 在制作标签时,人为地在前面添加了一个序号为0的none类。这是一个训练图文类模型的技巧,为了区分模型输出值是0和预测值是0这两种情况。
    for i in sub_dir:
        directory = os.path.join(dataset_path,i)
        if os.path.isdir(directory) == False: # 只处理文件夹中的数据
            print(directory)
            continue
        file_labels,strlables = load_dir(directory,labstart=start)
        tfile_lables.extend(file_labels)
        tstrlabels.extend(strlables)
        start = len(strlables)
    # 将数据路径与标签解压缩,把数据路径和标签解压缩出来
    filenames,labels = zip(*tfile_lables)
    return filenames, labels, tstrlabels

# 1.2 实现自定义数据集OwnDataset
def default_loader(path) : # 定义函数加载图片
    return Image.open(path).convert('RGB')

class OwnDataset(Dataset): # 复用性较强,可根据自己的数据集略加修改使用
    # 在PyTorch中,提供了一个torch.utis.data接口,可以用来对数据集进行封装。在实现时,只需要继承torch.utis.data.Dataset类,并重载其__gettem__方法。
    # 在使用时,框架会向__gettem__方法传入索引index,在__gettem__方法内部根据指定index加载数据,并返回。
    def __init__(self,img_dir,labels,indexlist=None,transform=transforms.ToTensor(),loader=default_loader,cache=True): # 初始化
        self.labels = labels # 存放标签
        self.img_dir = img_dir # 样本图片文件名
        self.transform = transform # 预处理方法
        self.loader = loader # 加载方法
        self.cache = cache # 缓存标志
        if indexlist is None: # 要加载的数据序列
            self.indexlist = list(range(len(self.img_dir)))
        else:
            self.indexlist = indexlist
        self.data = [None] * len(self.indexlist) # 存放样本图片

    def __getitem__(self, idx): # 加载指定索引数据
        if self.data[idx] is None: # 第一次加载
            data = self.loader(self.img_dir[self.indexlist[idx]])
            if self.transform:
                data = self.transform(data)
        else:
            data = self.data[idx]
        if self.cache: # 保存到缓存里
            self.data[idx] = data
        return data,self.labels[self.indexlist[idx]]

    def __len__(self): # 计算数据集长度
        return len(self.indexlist)

# 1.3 测试数据集:在完成数据集的制作之后,编写代码对其进行测试。
# 数据增强模块
from Rand_Augment import  Rand_Augment
data_transform = { #定义数据的预处理方法
    'train':transforms.Compose([
        Rand_Augment(), # 数据增强的方法带入 仅此一处修改
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'val':transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}
def Reduction_img(tensor,mean,std): #还原图片,实现了图片归一化的逆操作,显示数据集中的原始图片。
    dtype = tensor.dtype
    mean = torch.as_tensor(mean,dtype=dtype,device=tensor.device)
    std = torch.as_tensor(std,dtype=dtype,device=tensor.device)
    tensor.mul_(std[:,None,None]).add_(mean[:,None,None]) # 还原操作

dataset_path = r'./data/cub200/' # 加载数据集路径
filenames,labels,classes = load_data(dataset_path) # 调用load_data函数对数据集中图片文件名称和标签进行加载,其返回对象classes中包含全部的类名。
# 打乱数据顺序
# 110-115行对数据文件列表的序号进行乱序划分,分为测试数据集和训练数集两个索引列表。该索引列表会传入OwnDataset类做成指定的数据集。
np.random.seed(0)
label_shuffle_index = np.random.permutation(len(labels))
label_train_num = (len(labels)//10) * 8 # 划分训练数据集和测试数据集
train_list = label_shuffle_index[0:label_train_num]
test_list = label_shuffle_index[label_train_num:] # 没带:

train_dataset = OwnDataset(filenames,labels,train_list,data_transform['train'])# 实例化训练数据集
val_dataset = OwnDataset(filenames,labels,test_list,data_transform['val']) # 实例化测试数据集
# 实例化批次数据集:OwnDataset类所定义的数据集,其使用方法与PyTorch中的内置数据集的使用方法完全一致,配合DataLoader接口即可生成可以进行训练或测试的批次数据。具体代码如下。
train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=True)

sample = iter(train_loader) # 获取一批次数据,进行测试
images,labels = sample.next()
print("样本形状",np.shape(images))
print("标签个数",len(classes))
mulimgs = torchvision.utils.make_grid(images[:10],nrow=10) # 拼接多张图片
Reduction_img(mulimgs,[0.485,0.456,0.406],[0.229,0.224,0.225])
_img = ToPILImage()(mulimgs) # 将张量转化为图片
plt.axis('off')
plt.imshow(_img) # 显示
plt.show()
print(','.join('%5s' % classes[labels[j]] for j in range(len(images[:10]))))

# 1.4 获取并改造ResNet模型:获取ResNet模型,并加载预训练模型的权重。将其最后一层(输出层)去掉,换成一个全新的全连接层,该全连接层的输出节点数与本例分类数相同。
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# get_ResNet函数,获取预训练模型,可指定pretrained=True来实现自动下载预训练模型,也可指定loadfile来从本地路径加载预训练模型。
def get_ResNet(classes,pretrained=True,loadfile=None):
    ResNet = model.resnet101(pretrained) # 自动下载官方的预训练模型
    if loadfile != None:
        ResNet.load_state_dict(torch.load(loadfile)) # 加载本地模型
    # 将所有的参数层进行冻结:设置模型仅最后一层可以进行训练,使模型只针对最后一层进行微调。
    for param in ResNet.parameters():
        param.requires_grad = False
    # 输出全连接层的信息
    print(ResNet.fc)
    x = ResNet.fc.in_features # 获取全连接层的输入
    ResNet.fc = nn.Linear(x,len(classes)) # 定义一个新的全连接层
    print(ResNet.fc) # 最后输出新的模型
    return ResNet
ResNet = get_ResNet(classes) # 实例化模型
ResNet.to(device=device)

# 1.5 定义损失函数、训练函数及测试函数,对模型的最后一层进行微调。
criterion = nn.CrossEntropyLoss()
# 指定新加的全连接层的学习率
optimizer = torch.optim.Adam([{'params':ResNet.fc.parameters()}],lr=0.01)
def train(model,device,train_loader,epoch,optimizer): # 定义训练函数
    model.train()
    allloss = []
    for batch_idx,data in enumerate(train_loader):
        x,y = data
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat,y)
        loss.backward()
        allloss.append(loss.item())
        optimizer.step()
    print('Train Epoch:{}\t Loss:{:.6f}'.format(epoch,np.mean(allloss))) # 输出训练结果

def test(model,device,val_loader): # 定义测试函数
    model.eval()
    test_loss = []
    correct = []
    with torch.no_grad(): # 使模型在运行时不进行梯度跟踪,可以减少模型运行时对内存的占用。
        for i,data in enumerate(val_loader):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss.append(criterion(y_hat,y).item()) # 收集损失函数
            pred = y_hat.max(1,keepdim=True)[1] # 获取预测结果
            correct.append(pred.eq(y.view_as(pred)).sum().item()/pred.shape[0]) # 收集精确度
    print('\nTest:Average loss:{:,.4f},Accuracy:({:,.0f}%)\n'.format(np.mean(test_loss),np.mean(correct)*100)) # 输出测试结果

# 迁移学习的两个步骤如下
if __name__ == '__main__':
# 迁移学习步骤①:固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛。
    firstmodepth = './data/cub200/firstmodepth_1.pth' # 定义模型文件的地址
    if os.path.exists(firstmodepth) == False:
        print("—————————固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛—————————")
        for epoch in range(1,2): # 迭代两次
            train(ResNet,device,train_loader,epoch,optimizer)
            test(ResNet,device,val_loader)
        # 保存模型
        torch.save(ResNet.state_dict(),firstmodepth)
# 1.6 使用退化学习率对模型进行全局微调
#迁移学习步骤②:使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节,即将模型的每层权重都设为可训练,并定义带有退化学习率的优化器。(1.6部分)
    secondmodepth = './data/cub200/firstmodepth_2.pth'
    optimizer2 = optim.SGD(ResNet.parameters(),lr=0.001,momentum=0.9) # 第198行代码定义带有退化学习率的SGD优化器。该优化器常用来对模型进行手动微调。有实验表明,使用经过手动调节的SGD优化器,在训练模型的后期效果优于Adam优化器。
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer2,step_size=2,gamma=0.9) # 由于退化学习率会在训练过程中不断地变小,为了防止学习率过小,最终无法进行权重需要对其设置最小值。当学习率低于该值时,停止对退化学习率的操作。
    for param in ResNet.parameters(): # 所有参数设计为可训练
        param.requires_grad = True
    if os.path.exists(secondmodepth):
        ResNet.load_state_dict(torch.load(secondmodepth)) # 加载本地模型
    else:
        ResNet.load_state_dict(torch.load(firstmodepth)) # 加载本地模型
    print("____使用较小的学习率,对全部模型进行训练,定义带有退化学习率的优化器______")
    for epoch in range(1,100):
        train(ResNet,device,train_loader,epoch,optimizer2)
        if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:
            exp_lr_scheduler.step()
            print("___lr:",optimizer2.state_dict()['param_groups'][0]['lr'])
        test(ResNet,device,val_loader)
    # 保存模型
    torch.save(ResNet.state_dict(),secondmodepth)

Guess you like

Origin blog.csdn.net/qq_39237205/article/details/124097210