【15】宝可梦数据集基于迁移学习训练

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

1.迁移学习的概念

迁移学习的概念就是其实我们不必去重新的训练一个网络,而是我们可以基于其他的网络,借用这个网络的权重,然后稍微的去修改少层数的权重,从而达到一个比较好的效果。 在这里插入图片描述

常见的迁移学习方式:

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层

在pytorch中,含有很多网络结构的预处理模型,这些就是迁移学习的基础。

在这里插入图片描述

对于【14】自定义宝可梦数据集节中,实现的自定义数据集,如果我们选择自己写的ResNet18/50网络结构去训练(详情见【15】ResNet结构的pytorch实现),以10个epoch为例,最高的准确度acc只有80%左右。但是,如果使用迁移学习的方法,以30个epoch为例,最高的准确度可以达到0.974,测试集准确也有0.94。


2.迁移学习的实现

1)自定义模型结构参考代码

train.py

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'

train_data = Pokemon(root=root, resize=resize, mode='train')
val_data = Pokemon(root=root, resize=resize, mode='val')
test_data = Pokemon(root=root, resize=resize, mode='test')

train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size, shuffle=True)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ResNet50().to(device)
model = ResNet18()
print(model)

crition = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

best_acc = 0
best_epoch = 0

for epoch in range(epoch_size):

    # 训练集训练
    model.train()
    for batchidx, (image, label) in enumerate(test_loader):

        # image = image.to(device)
        # label = label.to(device)

        logits = model(image)
        loss = crition(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batchidx%2 == 0:
            print("epoch:{}/{}, batch:{}/{}, loss:{}"
                  .format(epoch+1, epoch_size, batchidx, len(test_loader), loss))

    # 测试集挑选
    model.eval()
    correct = 0
    for image, label in val_loader:

        # image = image.to(device)
        # label = label.to(device)

        with torch.no_grad():
            logits = model(image)
            pred = logits.argmax(dim=1)

        correct += torch.eq(pred, label).sum().float().item()

    acc = correct/len(val_data)
    print("epoch:{}, acc:{}".format(epoch+1, acc))

    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch

        torch.save(model.state_dict(), 'best.mdl')
        print("[get best epoch]- best_acc:{}, best_epoch:{}".format(best_acc, best_epoch))


复制代码

test.py

import torch
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'

test_data = Pokemon(root=root, resize=resize, mode='test')
test_loader = DataLoader(test_data, batch_size, shuffle=True)

model = ResNet18()
model.load_state_dict(torch.load('best.mdl'))

# 测试集验证
correct = 0
for image, label in test_loader:

    with torch.no_grad():
        logits = model(image)
        pred = logits.argmax(dim=1)

    correct += torch.eq(pred, label).sum().float().item()

print("len(test_loader):", len(test_data))
acc = correct/len(test_data)
print("final acc:", acc)
复制代码
2)迁移学习模型结构参考代码

(大多数的代码是相同的,主要是模型定义部分的改变)

from torchvision.models import resnet18
from utils import Flatten

# 迁移学习的主要实现
# model = ResNet18()
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # torch.Size([32, 512, 1, 1])
                      Flatten(),          # torch.Size([32, 512])
                      nn.Linear(512, 5)   # torch.Size([32, 5])
                      )
model.load_state_dict(torch.load('best.mdl'))
复制代码

utils.py

from    matplotlib import pyplot as plt
import  torch
from    torch import nn

# 打平操作
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

# 显示图像
def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
复制代码

对于之前的自定义结构,只需要稍微改变了几行代码,准确率就有了大大的提升,验证集也达到了0.94的效果。

所以,为了提高模型的准确率,可以使用一下迁移学习的方式。


猜你喜欢

转载自juejin.im/post/7098149885878206501