文章介绍
阅读本文时建议配合【自定义数据集-Pokenom Go_完整项目_CodingPark编程公园】一同阅读,因为本迁移学习加载的数据集为Pokemon中定义的数据集
TransferLearning
什么是迁移学习??
把这四个字拆成两个词:迁移➕学习
我们通过一个直观的例子来说明什么是迁移学习。假设你穿越到了古代,成为了太子,为了治理好国家,你需要知道的实在太多了。若是从头学起,肯定是来不及的。你要做的是找你的皇帝老爸,问问他过往的经验,而他也希望能将他脑子的知识一股脑的转移到你脑中。这正是迁移。即将一个领域的已经成熟的知识应用到其他的场景中。
再来
假设你已经有了一个可以高精确度分辨猫和狗的深度神经网络,你之后想训练一个能够分别不同品种的狗的图片模型,你需要做的不是从头训练那些用来分辨直线,锐角的神经网络的前几层,而是利用训练好的网络,提取初级特征,之后只训练最后几层神经元,让其可以分辨狗的品种,这正是学习
所以
迁移学习相当于让新一代的神经网络可以站在前人的基础上更进一步,而不必重新发明轮子。使用一个由他人预先训练好,应用在其他领域的网络,可以作为我们训练模型的起点。不论是有监督学习,无监督学习还是强化学习,迁移学习的概念都有广泛的应用。
举图像识别中最常见的例子,训练一个神经网络。来识别不同的品种的猫,你若是从头开始训练,你需要百万级的带标注数据,海量的显卡资源。而若是使用迁移学习,你可以使用Google发布的Inception或VGG16这样成熟的物品分类的网络,只训练最后的softmax层,你只需要几千张图片,使用普通的CPU就能完成,而且模型的准确性不差。
⚠️需要注意的是:预训练的神经网络和当前的任务差距大的话 -> 迁移学习的效果会很差
完整代码
——————————————————————————train_transfer.py———达到97%准确率———————————————————————
import torch
from torch import optim, nn
import visdom
import torchvision
import time
from torchvision.models import resnet18
from pokemon import Pokemon
from torch.utils.data import DataLoader
from utils import Flatten
# from resnet import ResNet18
batchsz = 32
lr = 0.001
epochs = 10
torch.manual_seed(1234)
'''
加载数据集
'''
train_db = Pokemon('pokemon', 224, mode='train')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=4)
viz = visdom.Visdom() # 创建visdom工具
def evaluate(model, loader): # val
model.eval()
print('----------val 运行----------')
correct = 0
total = len(loader.dataset)
for x, label in loader:
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, label).sum().float().item()
return correct / total
def evaluate_test(model, loader): # val
model.eval()
print('----------Test visdom显示----------')
print()
print('NameList --->',test_db.name2label)
print()
correct = 0
total = len(loader.dataset)
for x, label in loader:
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
viz.images(test_db.denormalize(x), nrow=8, win='test_sample_img', opts=dict(title='test_sample_img'))
viz.text(str(pred.numpy()), win='test_sample_text', opts=dict(title='test_sample_text'))
time.sleep(10)
correct += torch.eq(pred, label).sum().float().item()
print('----------Test visdom显示完毕----------')
return correct / total
def main():
print('----------Train 训练----------')
print()
print('----------Train visdom显示----------')
for x, label in train_loader:
viz.images(train_db.denormalize(x), nrow=8, win='trian_sample_img', opts=dict(title='trian_sample_img'))
viz.text(str(label.numpy()), win='trian_sample_text', opts=dict(title='trian_sample_text'))
time.sleep(2)
print()
print('----------Train visdom显示完毕----------')
# model = ResNet18(5) # 最后分为5类
trained_mode = resnet18(pretrained=True)
model = nn.Sequential(
*list(trained_mode.children())[:-1],
Flatten(),
nn.Linear(512,5)
)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss')) # 对应参数是(y, x)的顺序
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc')) # 初始化
for epoch in range(epochs):
model.train()
for step, (x,label) in enumerate(train_loader):
logits = model(x)
loss = criteon(logits, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append') # update viz.line 的 win='loss'
global_step += 1
if epoch % 1 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
best_acc = val_acc
torch.save(model.state_dict(), 'best.mdl') # 保存最好的模型
viz.line([val_acc], [global_step], win='val_acc', update='append') # update viz.line 的 win='val_acc'
print()
print('best acc', best_acc, 'best_epoch', best_epoch)
model.load_state_dict(torch.load('best.mdl'))
# print('loaded from ckpt => model.load_state_dict')
test_acc = evaluate_test(model, test_loader)
print()
print('test_acc', test_acc)
if __name__ == '__main__':
main()
结果展示