PyTorch transfer learning: ant/bee classification on a custom data set

1. Two main scenarios of transfer learning

  1. Fine-tune CNN: use a pre-trained network to initialize your own network instead of random initialization, and then train it
  2. Use the CNN as a fixed feature extractor: freeze all the earlier layers and replace the last fully connected layer, so that only this new layer is trained

Let's modify the pre-trained resnet18 network to train on a private data set to classify ants and bees

2. Data set download

The data set used here contains about 120 training pictures each of ants and bees, and 75 validation pictures of each. Since there are so few samples, it is difficult to get satisfactory results if a network is trained from scratch (with randomly initialized weights). This is where transfer learning comes in handy. Download the data set (the `hymenoptera_data` archive) and, after downloading, unzip it into the project directory

3. Import related packages

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import time
import os
import copy

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

4. Load data

PyTorch provides the torchvision.datasets.ImageFolder class for loading a custom data set laid out as one sub-directory per class:

# Training images get augmentation (random resized crop + horizontal flip)
# followed by normalization; validation images are only resized, center-cropped
# and normalized.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Root directory that holds the 'train' and 'val' sub-directories.
data_dir = 'hymenoptera_data'

# One ImageFolder data set per phase, each with its own transform pipeline.
image_datasets = {
    phase: torchvision.datasets.ImageFolder(
        os.path.join(data_dir, phase),
        data_transforms[phase],
    )
    for phase in ('train', 'val')
}

# Batches of 4 images, shuffled, loaded by 4 worker processes.
dataloaders = {
    phase: torch.utils.data.DataLoader(
        image_datasets[phase],
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )
    for phase in ('train', 'val')
}

# Number of samples in each phase; used to average loss and accuracy.
dataset_sizes = {phase: len(image_datasets[phase]) for phase in ('train', 'val')}

# Class labels as reported by the training ImageFolder.
class_names = image_datasets['train'].classes

5. Define a general training function to get the optimal parameters

def train_model(model, criterion, optimizer, scheduler, num_epochs=2):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a training pass followed by a validation pass over the
    module-level ``dataloaders`` (keyed by ``'train'`` / ``'val'``), moving
    every batch to the module-level ``device``. Per-phase loss and accuracy
    are averaged using the module-level ``dataset_sizes`` and printed.

    Args:
        model: Network to train; it is modified in place.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the trainable parameters.
        scheduler: Learning-rate scheduler (for example one from
            ``torch.optim.lr_scheduler``). It is stepped once per epoch,
            after the training pass.
        num_epochs: Number of train + validation epochs to run.

    Returns:
        The same ``model`` object, loaded with the weights of the epoch that
        reached the highest validation accuracy.
    """
    since = time.time()

    # Snapshot the starting weights so there is always a state to restore.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('-' * 20)
        print('Epoch {}/{}'.format(epoch+1, num_epochs))

        # Every epoch has a training phase and a validation phase.
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()   # training mode
            else:
                model.eval()    # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Track gradients only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass and parameter update only while training.
                    if is_training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # criterion returns a batch mean, so weight it by batch size.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            # Step the scheduler once per epoch, not once per batch: stepping
            # per batch makes a StepLR(step_size=5) decay the learning rate
            # every 5 batches and drives it to ~0 within the first epoch.
            if is_training:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Remember the weights with the best validation accuracy so far.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print('-' * 20)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Return the model carrying its best-performing weights.
    model.load_state_dict(best_model_wts)
    return model

6. Scenario 1: Fine-tuning the CNN

Here we use resnet18 as our initial network and continue to train the pre-trained model on our own data set. The only difference is that we change the output dimension of the network's final fully connected layer to 2, because we only need to predict whether a picture shows an ant or a bee; the original network has an output dimension of 1000 because it predicts 1000 categories:

net = torchvision.models.resnet18(pretrained=True)     # load the ResNet-18 architecture with pre-trained weights
num_ftrs = net.fc.in_features      # number of input features of the existing fc layer
net.fc = nn.Linear(num_ftrs, 2)    # replace the fc layer so the network outputs 2 values

net = net.to(device)

# Use cross-entropy as the loss function and SGD with momentum as the optimizer.
# All parameters of the network are handed to the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# StepLR multiplies the learning rate by gamma (0.1) after every
# step_size (5) calls to lr_scheduler.step().
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Train the model
net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=10)

Insert picture description here

7. Scenario 2: CNN as a fixed feature extractor

Here we set requires_grad = False to freeze every layer of the network except the last one, so that their gradients are not computed and their parameters are not updated during backpropagation:

net = torchvision.models.resnet18(pretrained=True)
# Freeze the parameters by setting requires_grad = False, so that their
# gradients are not computed during backpropagation.
for param in net.parameters():
    param.requires_grad = False

# Parameters of a newly constructed layer have requires_grad=True by default,
# so the replacement fc layer stays trainable.
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2)

net = net.to(device)

# Only the parameters of the new fc layer are handed to the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.fc.parameters(), lr=0.001, momentum=0.9)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=20)

Insert picture description here

Guess you like

Origin blog.csdn.net/zzh2910/article/details/103987523