SWA in practice: using SWA during fine-tuning to improve model generalization

Summary

Paper link: https://arxiv.org/abs/1803.05407.pdf

Official code: https://github.com/timgaripov/swa

Paper translation: [Part 32] SWA: Averaging Weights Leads to Wider Optima and Better Generalization - Programmer Sought

In short, SWA averages multiple checkpoints taken during training to improve the generalization performance of the model. Denote the checkpoint after epoch $i$ of training as $w_i$. Normally we would pick either the model from the last epoch, $w_n$, or the best model on the validation set, $w_i^*$, as the final model. SWA instead continues training for an additional period at the end, usually with a higher constant learning rate or a cyclical learning rate, and takes the average of the checkpoints collected during that phase.
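With equal weighting (which is also the default averaging used by PyTorch's AveragedModel unless a custom avg_fn is supplied), folding a new checkpoint $w$ into an average that already covers $n_{\mathrm{models}}$ checkpoints can be written as:

$$
w_{\mathrm{SWA}} \leftarrow \frac{n_{\mathrm{models}} \cdot w_{\mathrm{SWA}} + w}{n_{\mathrm{models}} + 1}
$$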

Example usage in PyTorch:

from torch.optim.swa_utils import AveragedModel, SWALR
# Use the SGD optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-3, momentum=0.9)
# Stochastic Weight Averaging (SWA) for better generalization
swa_model = AveragedModel(model).to(device)
# SWA learning-rate scheduler
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epochs + 1):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # Zero the gradients before the backward pass
        optimizer.zero_grad()
        output = model(data)
        # Compute the loss
        loss = train_criterion(output, target)
        # Backpropagate to obtain the gradients
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
    # Fold the current weights into the SWA running average once per epoch
    swa_model.update_parameters(model)
    swa_scheduler.step()
# Finally, recompute the BatchNorm statistics for the averaged model
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# Save the result (state_dict only)
torch.save(swa_model.state_dict(), "last.pt")

The code above shows the core SWA logic. The implementation steps are as follows (a self-contained sketch of the standard PyTorch recipe is given after the list):

1. Define the SGD optimizer.

2. Define the SWA averaged model (AveragedModel).

3. Define SWALR to adjust the learning rate.

4. Start training and wait for it to finish.

5. At the end of each epoch, update the SWA model's parameters and step the SWA learning-rate scheduler.

6. After training completes, update the parameters of the BN layers.
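For comparison, the recipe in the PyTorch documentation keeps a regular scheduler for most of training and only switches to SWA after a chosen swa_start epoch. A minimal, self-contained sketch of that pattern (the epochs, swa_start, learning rates and the train_one_epoch helper are illustrative, not taken from this post):

import torch
from torch.optim.swa_utils import AveragedModel, SWALR

# Assumes model, train_loader and a train_one_epoch(model, optimizer, loader) helper already exist
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-3)
epochs, swa_start = 100, 75  # illustrative values

for epoch in range(epochs):
    train_one_epoch(model, optimizer, train_loader)  # hypothetical helper
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # average only during the SWA phase
        swa_scheduler.step()
    else:
        scheduler.step()

# Recompute BatchNorm statistics for the averaged weights
torch.optim.swa_utils.update_bn(train_loader, swa_model)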

Detailed implementation process

Environment

pytorch: 1.10

Prepare

Before starting today's code, we need a trained model ready to load. With that prepared, we can get started.

Implementation process

Define the model and load the trained model. The code is as follows:

    # Build the model architecture; it must match the model that was trained
    model_ft = efficientnet_b1(pretrained=True)
    print(model_ft)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, classes)
    model_ft.to(DEVICE)
    # The whole model object was saved, so torch.load returns the model directly
    model_ft = torch.load(model_path)
    print(model_ft)
    fine_epoch = 80
    fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,
              use_amp)

Define the model as efficientnet_b1, which should be consistent with the trained model.

If the entire model was saved, use torch.load(model_path) to load it. If only the weights (state_dict) were saved, use model_ft.load_state_dict(torch.load(model_path)) instead.
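As a quick reference, the two loading styles look like this (model_path and DEVICE follow the post's code; map_location is an optional extra to make the load device-safe):

# If the entire model object was saved with torch.save(model_ft, model_path):
model_ft = torch.load(model_path, map_location=DEVICE)
# If only the weights were saved with torch.save(model_ft.state_dict(), model_path):
model_ft.load_state_dict(torch.load(model_path, map_location=DEVICE))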

Then, set the number of fine-tuning epochs to 80.

Next, let's look at the contents of the fine_tune function.

    # Use the SGD optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-3, momentum=0.9)
    if use_amp:
        # Note: the opt_level is the letter "O" followed by a one ("O1"), not zero-one
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

Define the optimizer as SGD.

If mixed precision is used, amp is initialized.
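Note that the post uses NVIDIA's apex library here. On PyTorch 1.10 the built-in torch.cuda.amp can be used instead; a minimal sketch of the equivalent training step under that alternative (reusing the post's model, train_criterion, samples and targets names):

scaler = torch.cuda.amp.GradScaler()

# Inside the batch loop:
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    output = model(samples)
    loss = train_criterion(output, targets)
scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
scaler.step(optimizer)          # unscale gradients and take the optimizer step
scaler.update()                 # adjust the scale factor for the next iteration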

    # Stochastic Weight Averaging (SWA) for better generalization
    swa_model = AveragedModel(model).to(device)
    # SWA learning-rate scheduler
    swa_scheduler = SWALR(optimizer, swa_lr=1e-6)

Initialize SWA.

Use SWALR to adjust the learning rate.
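SWALR anneals each parameter group's learning rate to swa_lr over a number of epochs and then holds it constant. Besides swa_lr, it exposes anneal_epochs and anneal_strategy; a sketch with these set explicitly (the values are illustrative, not from the post):

swa_scheduler = SWALR(
    optimizer,
    swa_lr=1e-6,            # learning rate held during the SWA phase
    anneal_epochs=5,        # how many epochs to anneal from the current lr down to swa_lr
    anneal_strategy="cos",  # "cos" or "linear"
)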

Next comes the epoch loop, which follows the usual training logic.

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            # Drop one sample from odd-sized batches (mixup_fn expects an even batch size)
            if len(data) % 2 != 0:
                print(len(data))
                data = data[0:len(data) - 1]
                target = target[0:len(target) - 1]
                print(len(data))
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            samples, targets = mixup_fn(data, target)
            output = model(samples)
            loss = train_criterion(output, targets)
            optimizer.zero_grad()
            if use_amp:
                # Scale the loss for apex mixed-precision training
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print_loss = loss.data.item()
            train_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
        # Fold the current weights into the SWA average and step the SWA scheduler once per epoch
        swa_model.update_parameters(model)
        swa_scheduler.step()

The main steps are:

1. Calculate loss.

2. Check whether amp mixed precision is in use: if so, backpropagate through scaled_loss to obtain the gradients; otherwise, backpropagate the loss directly.

3. swa_model.update_parameters(model) updates the parameters of swa_model.

4. swa_scheduler.step() updates the learning rate.

Wait for all epochs to complete.

torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), "last.pt")

Update the BN layer parameters.

Then save the model's weights. Note: only the weights (the state_dict) are saved here, not the entire model.
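update_bn is needed because the averaged weights were never used during training, so the BatchNorm running statistics stored in the model do not correspond to them. Conceptually, torch.optim.swa_utils.update_bn resets those statistics and runs one forward pass over the training data; a rough sketch of the idea (not the library's actual source):

# Reset BatchNorm running statistics, then recompute them with the averaged weights
for module in swa_model.modules():
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.reset_running_stats()
swa_model.train()
with torch.no_grad():
    for data, _ in train_loader:
        swa_model(data.to(device))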

After training completes, you can run the test code:

import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel
from timm.models.efficientnet import efficientnet_b1
import numpy as np

def show_outputs(output):
    # Print the top-5 class indices and their scores
    output_sorted = sorted(output, reverse=True)
    top5_str = '-----TOP 5-----\n'
    for i in range(5):
        value = output_sorted[i]
        # Positions in the output vector that hold this value
        index = np.where(output == value)[0]
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = '{}: {}\n'.format(index[j], value)
            else:
                topi = '-1: 0.0\n'
            top5_str += topi
    print(top5_str)

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = efficientnet_b1(pretrained=True)

num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 8)
# Wrap in AveragedModel so the state_dict keys match the saved SWA checkpoint
swa_model = AveragedModel(model)
swa_model.load_state_dict(torch.load("last.pt"))
swa_model.to(DEVICE)
swa_model.eval()

path = 'test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = img.to(DEVICE)
    out = swa_model(img)
    out = out.data.cpu().numpy()[0]
    print(file)
    show_outputs(out)

The test code here is essentially the same as the usual inference code; the only difference is that we redefine the model, wrap it in AveragedModel, and load the saved SWA weights.
Running result: (screenshot of the top-5 predictions for each test image)
Complete code:
https://download.csdn.net/download/hhhhhhhhhwwwwwwwwwwww/85223146

Origin blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/124414939