CIFAR10 データセット戦闘を簡単に実装するための数十行のコード — ResNet50 簡単な実装/pytorch

CIFAR10 データセット戦闘を簡単に実装するための数十行のコード — ResNet50 簡単な実装/pytorch

1 CIFAR-10

CIFAR-10 は、機械学習およびコンピューター ビジョン アルゴリズムをトレーニングするための画像のコレクションです。これには 60,000 枚の 32x32 カラー画像が含まれており、10 のカテゴリに分かれており、各カテゴリに 6,000 枚の画像が含まれています。10 カテゴリは、飛行機、車、鳥、猫、鹿、犬、カエル、馬、船、トラックです。CIFAR-10 は、2009 年にリリースされた ImageNet (1,400 万枚の小さな画像のデータセット) の注釈付きサブセットで、機械学習研究で最も広く使用されているデータセットの 1 つです。
CIFAR-10の一部

2 データセットを構築する

データセットのダウンロード速度が遅すぎる場合は、公式 Webサイトにアクセスして自分でダウンロードできます。

import torch, os
from torch import nn
from torch.utils.data import DataLoader as DataLoader
import torchvision

# 数据增强
transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),

        torchvision.transforms.RandomResizedCrop(
                (224, 224), scale=(0.9, 1), ratio=(0.9, 1.1)),
        torchvision.transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225]),
])
test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225]),
        ])

def load_cifar10(is_train=True, transform=None, batch_size=128):
    dataset = torchvision.datasets.CIFAR10(root="../Dataset", train=is_train,
                                           transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    return dataloader


train_iter = load_cifar10(True, transform, batch_size)
test_iter = load_cifar10(False, test_transform, batch_size)

3 ResNet50の設計

シンプルな設計、バックボーン ネットワーク resnet50 の機能抽出、そして最後にソフトマックス

class ResNet(nn.Module):
    def __init__(self, backend='resnet18'):
        self.backend = backend  # 卷积网络的后端
        # 调用父类的初始化方法
        super(ResNet, self).__init__()

        self.feature_extractor = getattr(torchvision.models, backend)(pretrained=True
                                                                      )
        self.cnn = nn.Sequential(
            self.feature_extractor.conv1,
            self.feature_extractor.bn1,
            self.feature_extractor.relu,
            self.feature_extractor.maxpool,
            self.feature_extractor.layer1,
            self.feature_extractor.layer2,
            self.feature_extractor.layer3,
            self.feature_extractor.layer4,
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.softmax = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_extractor.fc.in_features, 10)
        )

    def forward(self, x):
        features = self.cnn(x)
        y = self.softmax(features)
        return y

backend = 'resnet50'
net = ResNet(backend)

4 トレーニングを始める

from d2l import torch as d2l
from tqdm import tqdm
from torchsummary import summary

def train(net, train_iter, test_iter, num_epochs, lr, device):
    print('training on', device)
    net, resume_epoch = load_model(net, backend)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[resume_epoch+1, num_epochs+resume_epoch],ylim=[0, 1.0],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(resume_epoch, num_epochs+resume_epoch):
        print('epochs:',epoch+1)
        metric = d2l.Accumulator(3)
        net.train()
        iterator = tqdm(train_iter)
        for i, (X, y) in enumerate(iterator):
            timer.start()
            optimizer.zero_grad()
            # print(X.shape)
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
            status = f"epoch: {
      
      epoch}, loss: {
      
      train_l:.3f}, train_acc: {
      
      train_acc:.3f}"
            iterator.set_description(status)       

        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        print(f'loss {
      
      train_l:.3f}, train acc {
      
      train_acc:.3f}, '
              f'test acc {
      
      test_acc:.3f}')

    print(f'loss {
      
      train_l:.3f}, train acc {
      
      train_acc:.3f}, '
          f'test acc {
      
      test_acc:.3f}')
    print(f'{
      
      metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {
      
      str(device)}')
          
    torch.save({
    
    'model_state_dict': net.state_dict(),
                'epoch': epoch+1},
    			os.path.join('model/', "cifar10_" + backend + ".params"))
    d2l.plt.show()

def load_model(net, backend):
    if os.path.exists('model/' + "cifar10_" + backend + ".params"):
        info = torch.load('model/' + "cifar10_" + backend + ".params")
        net.load_state_dict(info['model_state_dict'])
        resume_epoch = info['epoch']
        print("cifar10_" + backend + ": Load Successful.")
    else:
        print("File not found.")
        resume_epoch = 0
    return net, resume_epoch
    ```

```python
if __name__ == '__main__':
	device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    lr = 1e-3
    batch_size = 128
	# 用这个函数可以查看设计的网络结构
    # summary(net, (1, 3, 224, 224), device='cuda')
    net = load_model(net, backend)
    train(net, train_iter, test_iter, 12, lr, device)
    torch.save(net.state_dict(),
               os.path.join('model/', "cifar10_" + backend + ".params"))

トレーニングプロセス
トレーニング プロセスを視覚化します。
トレーニングプロセス

5 件の結果

結果

おすすめ

転載: blog.csdn.net/Dec1steee/article/details/130735974