PyTorchを使用してCifar10をトレーニングします

トレーニングセットに5000枚の画像、各カテゴリに500枚の画像、検証セットに1000枚の画像、各カテゴリに100枚の画像。画像の命名形式を下図に示します。

トレーニングセットと検証セットは2つのフォルダーに保存されます。

 

class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()

        #input size [3*227*227]

        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)

        self.fc6 = nn.Linear(256 * 6 * 6, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 10)

    def forward(self, x):
        c1 = self.conv1(x)
        r1 = F.relu(c1)
        p1 = F.max_pool2d(r1, (3,3), stride=2)

        c2 = self.conv2(p1)
        r2 = F.relu(c2)
        p2 = F.max_pool2d(r2, (3,3), stride=2)

        c3 = self.conv3(p2)
        r3 = F.relu(c3)

        c4 = self.conv4(r3)
        r4 = F.relu(c4)

        c5 = self.conv5(r4)
        r5 = F.relu(c5)
        p5 = F.max_pool2d(r5, (3,3), stride=2)

        flatten = p5.view(-1, 256*6*6)

        f6 = self.fc6(flatten)
        r6 = F.relu(f6)
        d6 = F.dropout(r6)

        f7 = self.fc7(d6)
        r7 = F.relu(f7)
        d7 = F.dropout(r7)

        f8 = self.fc8(d7)

        return f8

トーチにはLRNレイヤーがないようです。切り抜きはありません。227* 227サイズを直接入力するだけです。

 

  • 次に、データ読み込み機能を実現します。毎回フォルダからbatch_sizeの写真をランダムに取得することを想像してみてください。自分でコードを書くのは面倒です。Torchが提供するDataLoaderクラスを使用して、データ読み込みインターフェイスを実装します。
class MyDataset_Cifar10(Dataset):
    def __init__(self, image_dir):
        self.root_dir = image_dir
        self.name_list = os.listdir(image_dir)
        self.label_list = []
        for i in self.name_list:
            name,id = i.split('_')
            id = id[:-4]
            self.label_list.append(id)

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, item):
        img = cv.imread(self.root_dir + self.name_list[item])
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = cv.resize(img, (227, 227))
        img = img.transpose((2, 0, 1))
        img = torch.tensor(img)
        label = self.label_list[item]
        label = torch.tensor(int(label))
        return img, label

主にデータがどこにあるかをDataLoaderに伝える必要がありますか?したがって、image_dirを渡す必要があります。次に、__ len__関数でデータの合計量を取得するために、データとラベルの対応テーブルを確立します。最後に、__ getitem__関数のアイテムアイテムに従って、データとラベルが返されます。ここでのデータとラベルは、純粋なRGBデータではなく、トレーニングに直接使用できるデータであることが望ましいため、上記のコードは浮動小数点、スケーリング、チャネル変換、および量子化処理を実行することに注意してください。

 

  • 次に、トレーニングを実行します。トレーニングプロセス:ネットワーク、損失関数、オプティマイザー、およびループを構築して、データローダーを取得します。trainにはnet.train()を使用し、valにはnet.eval()を使用します。コードでコメントアウトされているのは、動的学習率の設定です。
def train():
    max_epoch = 50
    test_epoch = 1
    display = 10
    train_batch_size = 128
    val_batch_size = 64

    net = AlexNet()
    net.cuda()

    best_model = net.state_dict()
    best_acc = 0.0

    cross_entropy_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    train_set = MyDataset_Cifar10('/home/dl/DeepHashing/CIFAR10/train/')
    val_set = MyDataset_Cifar10('/home/dl/DeepHashing/CIFAR10/query/')
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    valloader = torch.utils.data.DataLoader(val_set, batch_size=val_batch_size, shuffle=False)

    for e in range(max_epoch):
        print('Epoch {}/{}'.format(e,max_epoch))
        print('-' * 10)

        net.train()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = cross_entropy_loss(outputs, labels)

            loss.backward()
            optimizer.step()
            #scheduler.step()

            if i % display == 0:
                print('{} train loss:{} learning rate:{}'.
                      format(i*train_batch_size, loss.item(), optimizer.param_groups[0]['lr']))

        if e % test_epoch == 0:
            print('testing...')
            net.eval()
            acc = 0
            with torch.no_grad():
                for i, data in enumerate(valloader, 0):
                    inputs, labels = data
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())

                    outputs = net(inputs)
                    _, preds = torch.max(outputs.data, 1)
                    acc += torch.sum(preds == labels.data)

            acc = acc.item()/1000
            print('val acc:{}'.format(acc))

            if acc > best_acc:
                best_acc = acc
                best_model = net.state_dict()

    torch.save(best_model, './torch_test.pkl')

このコードの値は54.8%であり、インターネット上の他のalexnetの精度は60〜70であり、これは正しいはずです。5000サンプルのみを使用しました。

 

  • 最後に、モデルをテストします。図に示すように、平面を入力します。

def test():
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    img = cv.imread('/home/dl/test.jpg')
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = cv.resize(img, (227, 227))
    img = img.transpose((2, 0, 1))
    img = torch.tensor(img)
    img = img.unsqueeze(0)
    img = img.cuda()

    net = AlexNet()
    net.cuda()
    net.load_state_dict(torch.load('./torch_test.pkl'))
    net.eval()

    outputs = net(img)
    _, preds = torch.max(outputs.data, 1)
    print(classes[preds.item()])

出力カエル、なぜ出力カエル?-_-

 

補足:ヘッダーファイルと主な機能

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.optim.lr_scheduler
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
import os
import cv2 as cv
import numpy as np

if __name__=='__main__':
    #train()
    test()

補足:テストされたモデル(モデルは非常にヒップで、コードフローに精通しています)

リンク:https 
://pan.baidu.com/s/1YSDNUbwytFhw7X9_mQWciA抽出コード:dux9 

おすすめ

転載: blog.csdn.net/XLcaoyi/article/details/109754507