Training CIFAR10 with fastai and ResNet50 to 85% accuracy


fastai is a deep learning framework built on top of PyTorch that gives excellent results with little code. Below is the process of training a CIFAR10 classifier with it (the code uses the fastai v1 API).

Import the libraries

from fastai import *
from fastai.vision import *
from fastai.callbacks import CSVLogger, SaveModelCallback

Compute and display results on the validation set

def show_result(learn):
    # Get predictions on the validation set and report accuracy / error rate
    probs, val_labels = learn.get_preds(ds_type=DatasetType.Valid)
    print('Accuracy', accuracy(probs, val_labels))
    print('Error Rate', error_rate(probs, val_labels))

Display the confusion matrix and the most frequently confused classes

def show_matrix(learn):
    # Plot the confusion matrix for the trained model
    interp = ClassificationInterpretation.from_learner(learn)
    interp.confusion_matrix()
    interp.plot_confusion_matrix(dpi=120)

    # Show the most frequently confused classes; min_val is the minimum
    # number of errors to report (default 1)
    # Output order: actual, predicted, number of occurrences
    interp.most_confused(min_val=5)

    # Show the 9 samples with the highest loss
    # Each image is titled: prediction, actual, loss, probability of the actual class
    interp.plot_top_losses(9, figsize=(10, 10))
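
Once the learner defined below has been trained, all three diagnostics can be produced with a single call:

show_matrix(learn)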

Download the dataset. fastai calls the Linux tar command to extract it, which fails on Windows; in that case extract the archive manually. The extracted directory:

# Download the dataset
untar_data(URLs.CIFAR)

# Training data directory
path = Path(r'C:\Users\Administrator\.fastai\data\cifar10')
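
Note: untar_data returns the path of the extracted dataset, so on Linux/macOS (where tar is available) the hard-coded Windows path above is not needed:

# Download (if necessary), extract, and return the dataset directory
path = untar_data(URLs.CIFAR)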

Define the data and the on-the-fly data augmentation

# Define the on-the-fly data augmentation (random flips disabled)
tfms = get_transforms(do_flip=False)

data = (ImageList.from_folder(path)  # Where to find the data? -> in path and its subfolders
        .split_by_rand_pct()  # How to split in train/valid? -> randomly hold out 20% (the default) for validation
        .label_from_folder()  # How to label? -> from the name of each file's parent folder
        .add_test_folder()  # Optionally add a test set (here the default folder name is test)
        .transform(tfms, size=(32, 32))  # Data augmentation? -> apply tfms and resize to 32x32
        .databunch(bs=128)  # Finally? -> convert to an ImageDataBunch with batch size 128
        .normalize(imagenet_stats))

Inspect the data

# Inspect the class names, the number of classes, and the DataBunch itself
data.classes, data.c, data
(['airplane',
  'automobile',
  'bird',
  'cat',
  'deer',
  'dog',
  'frog',
  'horse',
  'ship',
  'truck'],
 10,
 ImageDataBunch;
 
 Train: LabelList (39072 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: CategoryList
 airplane,airplane,airplane,airplane,airplane
 Path: C:\Users\Administrator\.fastai\data\cifar10;
 
 Valid: LabelList (9767 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: CategoryList
 airplane,deer,deer,deer,automobile
 Path: C:\Users\Administrator\.fastai\data\cifar10;
 
 Test: LabelList (10000 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: EmptyLabelList
 ,,,,
 Path: C:\Users\Administrator\.fastai\data\cifar10)
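
As an optional sanity check (not part of the original run), a batch of augmented training images can also be displayed with the DataBunch's show_batch:

# Show a 3x3 grid of transformed training images with their labels
data.show_batch(rows=3, figsize=(7, 6))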

Create the learner

# Create the learner: pretrained ResNet50, tracking accuracy and error rate
learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate], callback_fns=[ShowGraph, SaveModelCallback])

Stage 1 training

# Search for a good learning rate
learn.lr_find(end_lr=1)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

# Plot the LR finder curve and show the suggested learning rate
learn.recorder.plot(suggestion=True)
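
With suggestion=True, the recorder also stores the suggested value, so it can be read programmatically instead of off the plot (a small convenience, assuming fastai v1's Recorder API):

# The LR at the point of steepest loss descent, set by plot(suggestion=True)
suggested_lr = learn.recorder.min_grad_lr
print('Suggested max_lr:', suggested_lr)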

# Train with the one-cycle policy, with max_lr taken from the LR curve
learn.fit_one_cycle(cyc_len=15, max_lr=1.78e-2)
epoch train_loss valid_loss accuracy error_rate time
0 1.074162 0.882136 0.709225 0.290775 01:17
1 0.824112 0.766163 0.740453 0.259547 01:16
2 0.811090 0.938345 0.707792 0.292208 01:16
3 0.799450 0.790665 0.733797 0.266203 01:16
4 0.763200 1.364758 0.752636 0.247364 01:18
5 0.693490 0.683559 0.776902 0.223098 01:16
6 0.673621 0.611799 0.800655 0.199345 01:16
7 0.665126 0.630715 0.796150 0.203850 01:16
8 0.612187 0.874567 0.826149 0.173851 01:16
9 0.563634 0.785189 0.820723 0.179277 01:16
10 0.515540 1.286271 0.829835 0.170165 01:21
11 0.485959 0.524455 0.840688 0.159312 01:16
12 0.444417 0.759944 0.842736 0.157264 01:17
13 0.419838 0.830482 0.845500 0.154500 01:17
14 0.421095 0.550606 0.844783 0.155217 01:16
 
 
Better model found at epoch 0 with val_loss value: 0.8821364045143127.
Better model found at epoch 1 with val_loss value: 0.7661632299423218.
Better model found at epoch 5 with val_loss value: 0.6835585832595825.
Better model found at epoch 6 with val_loss value: 0.6117991805076599.
Better model found at epoch 11 with val_loss value: 0.5244545340538025.

Training results

# Compute and display the validation results
show_result(learn)

 Accuracy tensor(0.8407)

Error Rate tensor(0.1593)

# Save the stage-1 model weights
learn.save('stg1')

Stage 2 training

learn.load('stg1')
learn.unfreeze()
learn.lr_find(end_lr=1)

 LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

learn.recorder.plot(suggestion=True)

learn.fit_one_cycle(15, slice(1e-6, 5e-5))
epoch train_loss valid_loss accuracy error_rate time
0 0.444569 0.521828 0.840381 0.159619 01:26
1 0.427062 0.513434 0.840483 0.159517 01:27
2 0.430344 0.514867 0.846524 0.153476 01:23
3 0.421480 0.550527 0.845295 0.154705 01:23
4 0.410170 0.506949 0.847855 0.152145 01:23
5 0.402150 0.542091 0.849186 0.150814 01:26
6 0.387639 0.491120 0.850927 0.149073 01:27
7 0.373022 0.511580 0.852155 0.147845 01:28
8 0.375497 0.505493 0.854101 0.145899 01:28
9 0.355466 0.585425 0.852462 0.147538 01:28
10 0.355327 0.506402 0.855534 0.144466 01:28
11 0.341208 0.498502 0.855944 0.144057 01:29
12 0.347057 0.549146 0.851746 0.148254 01:28
13 0.345185 0.533962 0.852155 0.147845 01:28
14 0.334336 0.504231 0.855432 0.144568 01:29
 
 
Better model found at epoch 0 with val_loss value: 0.5218283534049988.
Better model found at epoch 1 with val_loss value: 0.5134344696998596.
Better model found at epoch 4 with val_loss value: 0.5069490671157837.
Better model found at epoch 6 with val_loss value: 0.491120308637619.

Training results

# Compute and display the validation results
show_result(learn)

Accuracy tensor(0.8509)
Error Rate tensor(0.1491)

Save the model
learn.save('stg2')
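
Besides saving and loading weights, fastai v1 can also export the whole Learner (model plus preprocessing pipeline) for inference. A minimal sketch; the sample image path is hypothetical:

# Serialize the learner to export.pkl inside `path`
learn.export()

# Reload it later and classify a single image
infer_learn = load_learner(path)
img = open_image(path/'test'/'0001.png')  # hypothetical test image
pred_class, pred_idx, probs = infer_learn.predict(img)
print(pred_class)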

  

Confusion matrix

# Plot the confusion matrix for the trained model
interp = ClassificationInterpretation.from_learner(learn)
interp.confusion_matrix()
interp.plot_confusion_matrix(dpi=120)

Show the class pairs the model confuses most often (at least 5 errors); the output order is actual, predicted, number of occurrences.

interp.most_confused(5)
[('bird', 'frog', 86),
 ('truck', 'automobile', 71),
 ('deer', 'frog', 66),
 ('dog', 'bird', 59),
 ('airplane', 'ship', 57),
 ('bird', 'airplane', 54),
 ('dog', 'frog', 54),
 ('bird', 'deer', 53),
 ('dog', 'deer', 50),
 ('cat', 'dog', 47),
 ('deer', 'bird', 47),
 ('automobile', 'truck', 45),
 ('ship', 'airplane', 45),
 ('cat', 'frog', 44),
 ('bird', 'dog', 37),
 ('ship', 'automobile', 34),
 ('ship', 'truck', 32),
 ('airplane', 'bird', 31),
 ('deer', 'dog', 26),
 ('frog', 'bird', 25),
 ('dog', 'cat', 24),
 ('dog', 'horse', 24),
 ('airplane', 'automobile', 23),
 ('horse', 'deer', 23),
 ('airplane', 'truck', 22),
 ('airplane', 'deer', 20),
 ('frog', 'deer', 17),
 ('cat', 'deer', 16),
 ('horse', 'dog', 14),
 ('automobile', 'ship', 13),
 ('deer', 'horse', 13),
 ('truck', 'ship', 13),
 ('bird', 'ship', 12),
 ('cat', 'bird', 12),
 ('deer', 'airplane', 12),
 ('dog', 'truck', 12),
 ('truck', 'airplane', 12),
 ('frog', 'dog', 11),
 ('airplane', 'frog', 10),
 ('deer', 'ship', 10),
 ('dog', 'airplane', 9),
 ('frog', 'automobile', 8),
 ('horse', 'frog', 8),
 ('ship', 'bird', 8),
 ('cat', 'truck', 7),
 ('horse', 'airplane', 7),
 ('horse', 'bird', 7),
 ('ship', 'deer', 7),
 ('dog', 'automobile', 6),
 ('truck', 'frog', 6),
 ('automobile', 'frog', 5),
 ('bird', 'cat', 5),
 ('bird', 'truck', 5),
 ('cat', 'ship', 5),
 ('dog', 'ship', 5),
 ('frog', 'airplane', 5)]

The 9 samples with the highest loss

# Show the 9 samples with the highest loss
# Each image is titled: prediction, actual, loss, probability of the actual class
interp.plot_top_losses(9, figsize=(10, 10))
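
The test folder added with add_test_folder() is unlabeled, but predictions for it can still be generated; a small sketch using get_preds on the test set:

# Class probabilities for the 10000 unlabeled test images
test_probs, _ = learn.get_preds(ds_type=DatasetType.Test)
test_preds = [data.classes[int(i)] for i in test_probs.argmax(dim=1)]
print(test_preds[:10])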


Source: www.cnblogs.com/zhengbiqing/p/10923342.html