【项目实践】猫十二分类

【数据科学项目实践】基于ResNet和Inception v3的猫十二分类迁移学习

一、项目背景

本项目来源于飞浆平台的图像分类学习赛。指路链接

  • 代码和结果来源于我的小组同学,没有做任何的改动,我这边仅做一个总结归纳,以便学习和复盘

简单把赛题Copy一下:

本场比赛要求参赛选手对十二种猫进行分类,属于CV方向经典的图像分类任务。图像分类任务作为其他图像任务的基石,可以让大家更快上手计算机视觉。

数据集

比赛数据集包含12种猫的图片,并划分为训练集与测试集。

训练集: 提供高清彩色图片以及图片所属的分类,共有2160张猫的图片,含标注文件。

测试集: 仅提供彩色图片,共有240张猫的图片,不含标注文件。

二、Baseline

2.1 准备阶段

主要是导入一些要用到的模块:

import os
import cv2
import torch
import torch.nn as nn
from torchvision import models,transforms
from torch.utils.data import DataLoader,Dataset
import numpy as np
from PIL import Image
from torch.optim import lr_scheduler
import copy

2.2 数据读取阶段

这个阶段就是如何将数据读取到模型中来,由于猫猫是图像数据,所以这边将其读取成数字图像一般是通过数组来存在内存中的,考虑到中间过程的可视化,我们通过PIL来读取Image类型的数据。这步可以写作:

x=np.fromfile(imgPath,dtype=np.float32) # 读取成ndarray
x=cv2.imdecode(x,1) # 将区间转化为[0,255]
img=PIL.Image.fromarray(x) # 读取成Image对象

在这里插入图片描述

上图中,左边的是Image类型的数据,右边是cv读取的数据,可以发现发生了颜色通道的调换。实际上,读取到cv这部分就好了,可以调用多窗口的imshow进行数据可视化。

我们现在拿到了猫猫图像!那么接下来就要拿到猫猫的标签啦,一般情况下,我们会将数据跟标签记录在一个文档里,每一行对应一个数据(图片)路径和一个标签:

# 文件标签
filelist=r"data_split_list.txt"
imgs,labels=[],[] # 存储列表

with open(filelist) as f:
    lines=[_.strip() for _ in f] # 去除空白
    np.random.shuffle(lines) # 随机打乱
    for l in lines:
        img_path,label=l.split('\t') # 获取图片路径和标签
        img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,np.float32),1))
        imgs.append(img)
        labels.append(label)

我们将这部分工作封装成一个函数,就可以实现数据的读取了。

接下来的工作,就是将数据转化为PyTorch接受的格式啦。众所周知,PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,而DataLoader对象的数据集是一个DataSet类。所以这里我们需要构建一个Dataset类啦:

class myData(Dataset):
    
    def __init__(self):
        super(myData,self).__init__()
        self.data=[]
    
    def __getitem__(self,x):
        return self.data[x]
    
    def __len__(self):
        return len(self.data)

嗯,把上面三个函数填完就阔以啦。

对于图像数据,我们需要应用一个transforms,这里做最简单的变换:转为Tensor,尺寸裁剪,标准化

self.transform=transforms.Compose(
    transforms.ToTensor(),
    transforms.Resize((299,299)),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
)

最终的Dataset如下:

class myData(Dataset):

    def __init__(self,kind):
        super(myData, self).__init__()
        self.mode=kind
        self.transform=transforms.Compose(
            transforms.ToTensor(),
            transforms.Resize((299,299)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        )

        if kind=="test":
            self.imgs=self.load_origin_data()
        else:
            self.imgs,self.labels=self.load_origin_data()

    def __getitem__(self, item):
        if self.mode=="test":
            return self.transform(self.imgs[item])
        else:
            return self.transform(self.imgs[item]),torch.tensor(self.labels[item])

    def __len__(self):
        return len(self.imgs)

    def load_origin_data(self):
        filelist = './data/%s_split_list.txt' % self.mode
        imgs,labels=[],[]
        data_dir=os.getcwd()+"/data"
        if self.mode=='train' or self.mode=='val':
            with open(filelist) as f:
                lines=[_.strip() for _ in f]
                if self.mode=='train':
                    np.random.shuffle(lines)
                    for l in lines:
                        img_path,label=l.split('\t')
                        img_path=os.path.join(data_dir,img_path)
                        try:
                            img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,dtype=np.float32),1))
                            imgs.append(img)
                            labels.append(int(label))
                        except Exception("The path %s"%img_path+" may be wrong") as e:
                            print(e)
                            continue
                    return imgs,labels
                elif self.mode=="test":
                    full_lines = os.listdir('data/cat_12_test/')
                    lines = [line.strip() for line in full_lines]
                    for img_path in lines:
                        img_path = os.path.join(data_dir, "cat_12_test/", img_path)
                        img = Image.open(img_path)
                        imgs.append(img)
                    return imgs

2.3 模型训练

我们刚刚说PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,现在就是需要构建这个东西啦:

def get_Dataloader():
    img_datasets = {
    
    x: myData(x) for x in ['train', 'val', 'test']}
    dataset_sizes = {
    
    x: len(img_datasets[x]) for x in ['train', 'val', 'test']}

    train_loader = DataLoader(
        dataset=img_datasets['train'],
        batch_size=24,
        shuffle=True
    )

    val_loader = DataLoader(
        dataset=img_datasets['val'],
        batch_size=1,
        shuffle=False
    )

    test_loader = DataLoader(
        dataset=img_datasets['test'],
        batch_size=1,
        shuffle=False
    )

    dataloaders = {
    
    
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
    return dataset_sizes,dataloaders

接下来就是单纯的训练过程了。步骤总结如下:

  • 参数设置阶段
    • 设置GPU
    • 设置优化器、损失函数、学习策略
  • 训练过程
    • 迭代DataLoader
    • 优化器梯度清零
    • 模型推理
    • 误差计算
    • 反向传播
    • 更新优化器、学习率
  • 模型评估
    • 计算每轮的误差累计值、精度
    • 选择最优精度并进行模型保存
def Train(model,criterion,optimizer,scheduler,num_epoches=25):
    dataset_sizes,dataloaders=get_Dataloader()
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0

    for epoch in range(num_epoches):
        print("Epoch {}/{}".format(epoch+1,num_epoches))

        for phase in ['train','val']:
            if phase=="train":
                model.train()
            else:
                model.eval()

            trian_loss=0.0
            train_corrects=0

            for inputs,labels in dataloaders[phase]:
                inputs,labels=inputs.to(device),labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=="train"):
                    # 上下文管理器,参数是Bool,用于确定是否对Block内的语句进行求导
                    y_pre=model(inputs)
                    _,y_pre=torch.max(y_pre,1)
                    loss=criterion(y_pre,labels)

                    if phase=="train":
                        loss.backward()
                        optimizer.step()

                trian_loss+=loss.item()*inputs.size(0)
                train_corrects+=torch.sum(y_pre==labels)
            if phase=="train":
                scheduler.step()

            epoch_loss=trian_loss/dataset_sizes[phase]
            epoch_acc=train_corrects.float()/dataset_sizes[phase]

            print("{} Loss :{:.4f} Acc {:.4}".format(phase,epoch_loss,epoch_acc))

            if phase=="val" and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
    print("Best val Acc : {:4f}".format(best_acc))
    model.load_state_dict(best_model_wts)
    return model

三、迁移学习

迁移学习(Transfer Learning)就是利用预训练好的大模型参数去学习其他数据的分布。

这个过程我们一般不希望原始模型参数改变,因而一般需要做如下工作:

for param in model.parameters():
    param.requires_grad=False

然后,我们需要构架最后一层全连接层,用来学习新的数据集:

model.fc=nn.Linear(2048,num_classes)

也就是最后需要训练的就是这个全连接层了。

def Inception(device):
    # 用训练好的模型进行迁移
    model_ft=models.inception_v3(pretrained=True)
    # model_ft=models.resnet50(pretrained=True)
    # model_ft=models.alexnet(pretrained=True)

    num_ftrs=model_ft.fc.in_features
    model_ft.fc=nn.Linear(num_ftrs,12) # 设置全连接层最终结果
    
    model_ft=model_ft.to(device)

    cirterion=nn.CrossEntropyLoss()
    optimizer_ft=torch.optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9)
    exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=5,gamma=0.1)
    model_ft=Train(model_ft,cirterion,optimizer_ft,exp_lr_scheduler,num_epoches=30)

四、结果分析

  • Inception

    Epoch 30/30
    train Loss: 0.1065 Acc: 0.9858
    val Loss: 0.3026 Acc: 0.8983
    Best val Acc: 0.918336
    
  • AlexNet

    Epoch 30/30
    train Loss: 0.1403 Acc: 0.9601
    val Loss: 0.6815 Acc: 0.7750
    Best val Acc: 0.779661
    
  • ResNet50

    Epoch 30/30
    train Loss: 0.0480 Acc: 0.9973
    val Loss: 0.3157 Acc: 0.9060
    Best val Acc: 0.909091
    

中间部分特征图的结果如下:

在这里插入图片描述

特征图嘛,主打的就是一个抽象。可以发现同一张图经过不同的卷积核作用后,有了全新的高维特征,这些特征也主打的就是一个难以解释,反正就看个乐。

在这里插入图片描述

基本上7个epoch就收敛了。

猜你喜欢

转载自blog.csdn.net/qq_45957458/article/details/130877398