使用PyTorch训练Cifar10

训练集5000张图片,每类500张,验证集1000张,每类100张。图片命名格式如下图所示。

训练集、验证集分为两个文件夹存放。

class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()

        #input size [3*227*227]

        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)

        self.fc6 = nn.Linear(256 * 6 * 6, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 10)

    def forward(self, x):
        c1 = self.conv1(x)
        r1 = F.relu(c1)
        p1 = F.max_pool2d(r1, (3,3), stride=2)

        c2 = self.conv2(p1)
        r2 = F.relu(c2)
        p2 = F.max_pool2d(r2, (3,3), stride=2)

        c3 = self.conv3(p2)
        r3 = F.relu(c3)

        c4 = self.conv4(r3)
        r4 = F.relu(c4)

        c5 = self.conv5(r4)
        r5 = F.relu(c5)
        p5 = F.max_pool2d(r5, (3,3), stride=2)

        flatten = p5.view(-1, 256*6*6)

        f6 = self.fc6(flatten)
        r6 = F.relu(f6)
        d6 = F.dropout(r6)

        f7 = self.fc7(d6)
        r7 = F.relu(f7)
        d7 = F.dropout(r7)

        f8 = self.fc8(d7)

        return f8

Torch里面好像没有LRN层。也没有Crop,直接227*227大小输进去。

  • 然后,实现数据加载功能。想象每次随机从一个文件夹取batch_size张图片,自己写代码挺麻烦,我们使用Torch提供的DataLoader类实现数据加载接口。
class MyDataset_Cifar10(Dataset):
    def __init__(self, image_dir):
        self.root_dir = image_dir
        self.name_list = os.listdir(image_dir)
        self.label_list = []
        for i in self.name_list:
            name,id = i.split('_')
            id = id[:-4]
            self.label_list.append(id)

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, item):
        img = cv.imread(self.root_dir + self.name_list[item])
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = cv.resize(img, (227, 227))
        img = img.transpose((2, 0, 1))
        img = torch.tensor(img)
        label = self.label_list[item]
        label = torch.tensor(int(label))
        return img, label

主要需要告诉DataLoader你的数据在哪?所以需要传入image_dir。其次建立数据和标签的对应关系表,从而在__len__函数中得到数据的总量。最后,根据__getitem__函数的item项,返回一个数据和一个标签。注意,这里的数据和标签最好是能直接拿来训练的数据,而不是纯RGB数据,所以上面的代码进行了浮点型,缩放,通道转换,张量化处理。

  • 接着,进行训练。训练的流程:构造网络,损失函数,优化器,循环取Dataloader。train的时候用net.train(),val的时候用net.eval()。代码中注释掉的是动态学习率设置。
def train():
    max_epoch = 50
    test_epoch = 1
    display = 10
    train_batch_size = 128
    val_batch_size = 64

    net = AlexNet()
    net.cuda()

    best_model = net.state_dict()
    best_acc = 0.0

    cross_entropy_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    train_set = MyDataset_Cifar10('/home/dl/DeepHashing/CIFAR10/train/')
    val_set = MyDataset_Cifar10('/home/dl/DeepHashing/CIFAR10/query/')
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    valloader = torch.utils.data.DataLoader(val_set, batch_size=val_batch_size, shuffle=False)

    for e in range(max_epoch):
        print('Epoch {}/{}'.format(e,max_epoch))
        print('-' * 10)

        net.train()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = cross_entropy_loss(outputs, labels)

            loss.backward()
            optimizer.step()
            #scheduler.step()

            if i % display == 0:
                print('{} train loss:{} learning rate:{}'.
                      format(i*train_batch_size, loss.item(), optimizer.param_groups[0]['lr']))

        if e % test_epoch == 0:
            print('testing...')
            net.eval()
            acc = 0
            with torch.no_grad():
                for i, data in enumerate(valloader, 0):
                    inputs, labels = data
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())

                    outputs = net(inputs)
                    _, preds = torch.max(outputs.data, 1)
                    acc += torch.sum(preds == labels.data)

            acc = acc.item()/1000
            print('val acc:{}'.format(acc))

            if acc > best_acc:
                best_acc = acc
                best_model = net.state_dict()

    torch.save(best_model, './torch_test.pkl')

这段代码跑出来val是54.8%,网上其他alexnet的准确率在60-70这样,应该是对的,我们只用了5000样本。

  • 最后,测试模型。输入一张飞机,如图:

扫描二维码关注公众号,回复: 12826314 查看本文章
def test():
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    img = cv.imread('/home/dl/test.jpg')
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = cv.resize(img, (227, 227))
    img = img.transpose((2, 0, 1))
    img = torch.tensor(img)
    img = img.unsqueeze(0)
    img = img.cuda()

    net = AlexNet()
    net.cuda()
    net.load_state_dict(torch.load('./torch_test.pkl'))
    net.eval()

    outputs = net(img)
    _, preds = torch.max(outputs.data, 1)
    print(classes[preds.item()])

输出frog,为什么输出frog?-_-

补充:头文件和主函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.optim.lr_scheduler
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
import os
import cv2 as cv
import numpy as np

if __name__=='__main__':
    #train()
    test()

补充:测试的模型(模型效果很拉胯,熟悉个代码流程)

链接:https://pan.baidu.com/s/1YSDNUbwytFhw7X9_mQWciA 
提取码:dux9 

猜你喜欢

转载自blog.csdn.net/XLcaoyi/article/details/109754507