使用PyTorch执行特征提取和微调的迁移学习来进行图像分类

这篇博客将介绍如何使用PyTorch深度学习库执行图像分类的转移学习。
① 通过特征提取执行迁移学习
② 通过微调执行迁移学习

第①种方法通常更容易实现,在某些情况下效果很好。然而,它往往不如第二种方法准确。即模型的准确性和泛化能力都会受到影响。大多数形式的迁移学习都采用②微调。

通常建议使用特征提取方法来获得基线精度。如果准确度足以满足那就太棒了!然而,如果精度不够,那么应该进行微调,看看是否可以提高精度。
无论式通过特征提取还是微调的迁移学习,都会为你节省大量的时间和精力,而不是从头开始训练模型。

1. 效果图

5种花朵数据集,分别为雏菊、蒲公英、玫瑰、向日葵、郁金香,效果图如下:
在这里插入图片描述

特征提取效果图如下:
在这里插入图片描述
在这里插入图片描述

微调效果图如下:
在这里插入图片描述

在这里插入图片描述

2 项目结构

pip install opencv-contrib-python
pip install torch torchvision
pip install imutils matplotlib tqdm
  • 用于存储重要变量的配置脚本
  • 数据集加载器辅助函数
  • 在磁盘上构建和组织数据集的脚本,例如PyTorch的ImageFolder 和数据加载器类可以很容易地利用
  • 通过特征提取执行基本迁移学习的驱动程序脚本
  • 第二个驱动程序脚本,通过用全新的、新初始化的FC头替换预训练网络的完全连接(FC)层头来执行微调
  • 一个最终脚本,允许我们使用经过训练的模型进行推理

3 什么是迁移学习

从头开始训练卷积神经网络带来了许多挑战,最显著的是训练网络的数据量和进行训练所需的时间。

迁移学习是一种技术,允许使用为特定任务训练的模型作为不同任务的机器学习模型的起点。
例如,假设在ImageNet数据集上对模型进行图像分类训练。在这种情况下可以采用这个模型并“重新训练”它来识别它最初从未被训练来识别的类!
想象一下,你知道如何骑自行车,却想骑摩托车。你骑自行车的经验——保持平衡、保持方向、转弯和刹车——将帮助你更快地学会骑摩托车。这就是迁移学习在CNN的情况下所做的。使用迁移学习可以通过冻结参数、更改输出层和微调权重来直接使用训练有素的模型。
本质上可以缩短整个训练过程,并在很短的时间内获得高精度的模型。

4 如何使用PyTorch进行迁移学习?

迁移学习主要有两种类型:

  • 通过特征提取进行迁移学习(Transfer learning via feature extraction):从预先训练的网络中移除FC(Fully Connection)层头,并用softmax分类器替换它。这种方法非常简单,因为它允许将预先训练的CNN视为特征提取器,然后将这些特征通过Logistic回归分类器。

  • 通过微调进行迁移学习(Transfer learning via fine-tuning):在应用微调时,再次从预先训练的网络中移除FC(Fully Connection)层头,但这次构建了一个全新的、新初始化的FC层头并将其放置在网络的原始主体上。CNN主体中的权重被冻结,然后训练新的层头(通常具有非常小的学习率)。然后可以选择解冻网络的主体并训练整个网络。

第一种方法往往更容易使用,因为涉及的代码更少,需要调整的参数也更少。然而,第二种方法往往更准确,导致模型更好地推广。通过特征提取和微调的迁移学习都可以用PyTorch实现——我将在本教程的其余部分向您展示如何实现。

5 花朵数据集

将用于微调实验的数据集是由TensorFlow开发团队策划的花朵图像数据集。该数据集的3670幅图像属于五种不同的花卉:

  • 雏菊:633张图片
  • 蒲公英:898张图片
  • 玫瑰:641张图片
  • 向日葵:699张图片
  • 郁金香:799张图片

目标是训练一个图像分类模型来识别这些花的每一种,将通过PyTorch应用迁移学习来实现这一目标。

源码

train_feature_extraction.py

# flower_photos: 5种花朵原始图片集
# config.py 配置文件将存储驱动程序脚本中使用的重要变量和参数。与其在每个脚本中重新定义它们只需在这里定义一次(从而使代码更干净、更容易阅读)
# create_dataloader.py help函数,Dataloader加载flower_photos
# output/ 存放训练损失图
# build_dataset.py 根据flower_photos目录构建数据集目录,将创建特殊的子目录来存储训练和验证拆分,允许PyTorch的ImageFolder脚本来解析目录并训练模型
# train_feature_extraction.py 执行特征提取的迁移学习,并把模型存储磁盘
# fine_tune.py 执行基于微调的迁移学习,并把模型存储磁盘
# inference.py 接受经过训练的PyTorch模型,并使用它对输入的花朵图像进行预测

# 要实现的第一种迁移学习方法是特征提取
# 通过特征提取进行迁移学习的工作原理如下:
# 采用预先训练的CNN(通常在ImageNet数据集上),从CNN上卸下FC(Fully Connection)层头,将网络主体的输出视为空间维度为M×N×C的任意特征提取器
# 分类器有俩个选择:
# 采用标准的逻辑回归分类器(如scikit学习库中的分类器),并根据每个图像中提取的特征对其进行训练。或者,更简单地说,将softmax分类器放在网络主体的顶部,
# 任何一种选择都是可行的,而且或多或少与另一种“相同”。
# 当提取的特征数据集适合机器的RAM时,第一个选项非常有效。这样可以加载整个数据集,实例化逻辑回归分类器模型的一个实例,然后对其进行训练。
# 当数据集太大而无法放入机器内存时,就会出现问题。当这种情况发生时,你可以使用类似在线学习的方法来训练你的逻辑回归分类器,但这只是引入了另一组库和依赖项。
# 相反,更容易的是利用PyTorch的强大功能,在提取的特征之上创建一个类似逻辑回归的分类器,然后使用PyTorch函数对其进行训练。

# 训练特征提取模型,执行该脚本后,将在输出目录中找到一个名为warmup_model.pth的文件——该文件是序列化PyTorch模型,然后可以用于在inference.py脚本中进行预测。
# 总的训练时间只有5分钟多一点,获得了84.26%的训练准确率和87.74%的验证准确率。
# USAGE
# python train_feature_extraction.py

# 导入必要的包
from pyimagesearch import config
from pyimagesearch import create_dataloaders # 从输入数据集目录创建PyTorch DataLoader的实例
from imutils import paths
from torchvision.models import resnet50 # 要使用的ImageNet的预训练模型
from torchvision import transforms # 允许定义一组预处理和/或数据增强,将依次应用于输入图像
from tqdm import tqdm # 用于创建格式良好的进度条的Python库
from torch import nn # 包含PyTorch的神经网络类和函数
import matplotlib.pyplot as plt
import numpy as np
import torch # 包含PyTorch的神经网络类和函数
import time

# 定义增强管道(使用Compose函数构建数据处理/扩充步骤,该函数位于PyTorch的transforms子模块中。
# 首先创建一个trainTransform,在给定输入图像的情况下,它将:
# 随机调整图像大小并将其裁剪为image_SIZE尺寸
# 随机执行水平翻转
# 在[-90,90]范围内随机执行旋转
# 将生成的图像转换为PyTorch张量
# 执行平均值减法和缩放,同样的用于验证数据集的 valTransform
# 请注意,我们不在验证转换器中执行数据扩充——没有必要对验证数据执行数据扩充。
# 创建了训练和验证Compose对象后,让我们应用get_datalader函数:)
trainTansform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 创建DataLoader
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
                                                           transforms=trainTansform,
                                                           batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                       transforms=valTransform,
                                                       batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)

# 通过特征提取为迁移学习准备ResNet50模型
# 加载预训练的ImageNet ResNet50 model
model = resnet50(pretrained=True)

# 由于使用ResNet50模型作为特征提取器,设置其参数为不可训练(默认情况下是可训练的)
for param in model.parameters():
    param.requires_grad = False

# 将一个新的分类顶部附加到我们的特征提取器并弹出它,连接到当前设备
# 创建一个由单个FC层组成的新FC层头。实际上当使用分类交叉熵损失进行训练时,这一层将作为代理softmax分类器。
# 然后,这个新层被附加到网络主体,模型本身被移动到设备(CPU或GPU)。
modelOutputFeats = model.fc.in_features
model.fc = nn.Linear(modelOutputFeats, len(trainDS.classes))
model = model.to(config.DEVICE)

# 接下来,初始化损失函数和优化方法(注意只是向优化器提供分类顶部的参数)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.fc.parameters(), lr=config.LR)

# 计算训练/验证集的每一个纪元步数
trainSteps = len(trainDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDS) // config.FEATURE_EXTRACTION_BATCH_SIZE

# 初始化字典以存储训练历史
H = {
    
    "train_loss": [], "train_acc": [], "val_loss": [],
     "val_acc": []}

# 遍历纪元
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # 设置模型训练模式
    model.train()

    # 初始化训练/验证损失
    totalTrainLoss = 0
    totalValLoss = 0

    # 初始化训练/验证集中的预测正确个数
    trainCorrect = 0
    valCorrect = 0

    # 遍历训练集
    # 对于trainLoader中的每一批数据,将图像和类标签移动到CPU/GPU、对数据进行预测、计算损失,计算梯度,更新模型权重,并将梯度归零
    # 累积在该时期的总训练损失、计算正确预测的总数
    for (i, (x, y)) in enumerate(trainLoader):
        # 传递输入到设备
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

        # 向前传递并计算训练损失
        pred = model(x)
        loss = lossFunc(pred, y)

        # 计算损失梯度
        loss.backward()

        # 检查是否正在更新模型参数,如果是 更新它们,并将之前累积的梯度清零
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()

        # 将损失加上迄今为止的总训练损失,同样累加正确预测的数量
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()

        # 关闭autograd并将模型置于评估模式中——这是使用PyTorch进行评估时的要求
        # switch off autograd
        with torch.no_grad():
            # 设置模型为评估模式
            model.eval()

            # 在valLoader中循环所有数据点,对它们进行预测,并计算总损失和正确验证预测的数量。
            # 遍历验证集
            for (x, y) in valLoader:
                # 传递输入到设备
                (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

                # 预测并计算验证损失
                pred = model(x)
                totalValLoss += lossFunc(pred, y)

                # 计算正确预测的数量
                valCorrect += (pred.argmax(1) == y).type(
                    torch.float).sum().item()

        # 以下代码块汇总训练/验证损失和准确性,更新训练历史记录,然后将损失/准确性信息打印到终端
        # 计算平均训练/验证损失
        avgTrainLoss = totalTrainLoss / trainSteps
        avgValLoss = totalValLoss / valSteps

        # 计算训练/验证准确性
        trainCorrect = trainCorrect / len(trainDS)
        valCorrect = valCorrect / len(valDS)

        # 更新训练历史
        H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
        H["train_acc"].append(trainCorrect)
        H["val_loss"].append(avgValLoss.cpu().detach().numpy())
        H["val_acc"].append(valCorrect)

        # 打印模型训练、验证信息
        print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
        print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
            avgTrainLoss, trainCorrect))
        print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
            avgValLoss, valCorrect))

# 绘制训练历史,序列化模型到磁盘
# 展示训练模型的总耗时
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# 绘制训练/验证损失和准确性图
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.WARMUP_PLOT)

# 序列化模型到磁盘
torch.save(model, config.WARMUP_MODEL)

fine_tune.py

# ① 通过特征提取执行迁移学习
# ② 通过微调执行迁移学习
#
# ①在某些情况下效果很好,但其简单性也有缺点,即模型的准确性和泛化能力都会受到影响。大多数形式的迁移学习都采用②微调。

# 与特征提取类似,首先从网络中移除FC层头,但这次创建了一个全新的层头,其中包含一组线性、ReLU和丢弃层,类似于您在现代最先进的CNN上看到的内容。
# 然后执行以下组合:
# 冻结网络主体中的所有层并训练层头
# 冻结所有层,训练层头,然后解冻身体并训练
# 只需将所有图层解冻并一起训练即可
# 确切地说,你使用哪种方法是你自己进行的实验——一定要测量哪种方法的损失最小,准确度最高!

# 通过PyTorch的迁移学习应用微调
# 由于模型更为复杂(由于在网络主体中添加了新的FC层头),现在训练需要大约6.5分钟。然而在图4中获得了比简单特征提取方法更高的精度(分别为90.83%/90.19%和84.26%/87.74%)
# 虽然执行微调确实需要更多的工作,但通常会发现精度更高,模型会更好地推广。
# USAGE
# python fine_tune.py

# import the necessary packages
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from imutils import paths
from torchvision.models import resnet50
from torchvision import transforms
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch
import time
import os

# 定义训练和验证转换,和对特征提取所做的相同
# 定义增强管道(使用Compose函数构建数据处理/扩充步骤,该函数位于PyTorch的transforms子模块中。
# 首先创建一个trainTransform,在给定输入图像的情况下,它将:
# 随机调整图像大小并将其裁剪为image_SIZE尺寸
# 随机执行水平翻转
# 在[-90,90]范围内随机执行旋转
# 将生成的图像转换为PyTorch张量
# 执行平均值减法和缩放,同样的用于验证数据集的 valTransform
# 请注意,我们不在验证转换器中执行数据扩充——没有必要对验证数据执行数据扩充。
# 创建了训练和验证Compose对象后,让我们应用get_datalader函数:)
trainTansform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 创建DataLoader
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
                                                           transforms=trainTansform,
                                                           batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                       transforms=valTransform,
                                                       batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)


# 加载ResNet模型,其中包含在ImageNet数据集上预先训练的权重。在这个微调中将构建一个新的FC层头,然后同时训练FC层头和网络主体。
# 然而首先需要密切关注网络架构中的批处理规范化层。这些层具有特定的平均值和标准偏差值,这些值是在最初在ImageNet数据集上训练网络时获得的。
# 不想在训练期间更新这些统计数据,冻结了BatchNorm2d的任何实例。构建新的headModel,它由一系列FC=>RELU=>DROPOUT层组成,最后一个线性层的输出是数据集中的类的数量,最后将新的headModel添加到网络中,从而替换旧的FC层头。
# 加载预训练的ImageNet ResNet50 model
# 真正的变化来自于从磁盘加载ResNet并修改体系结构本身
model = resnet50(pretrained=True)
numFeatures = model.fc.in_features

# 遍历模型的模块,设置批量归一化为非训练状态
for module, param in zip(model.modules(), model.parameters()):
    if isinstance(module, nn.BatchNorm2d):
        param.requires_grad = False

# 定义网络头,添加到模型
headModel = nn.Sequential(
    nn.Linear(numFeatures, 512),
    nn.ReLU(),
    nn.Dropout(0.25),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, len(trainDS.classes))
)
model.fc = headModel

# 将一个新的分类顶部附加到微调模型并连接到当前设备
model = model.to(config.DEVICE)

# 初始化损失函数和优化方法(注意只是向优化器提供分类顶部的参数)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=config.LR)

# 计算训练/验证集的每一个纪元步数
trainSteps = len(trainDS) // config.FINETUNE_BATCH_SIZE
valSteps = len(valDS) // config.FINETUNE_BATCH_SIZE

# 初始化字典以存储训练历史
H = {
    
    "train_loss": [], "train_acc": [], "val_loss": [],
     "val_acc": []}

# 遍历纪元
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # 设置模型为训练模式
    model.train()

    # 初始化总训练和验证损失
    totalTrainLoss = 0
    totalValLoss = 0

    # 初始化训练/验证的正确预测数
    trainCorrect = 0
    valCorrect = 0

    # 遍历训练集
    for (i, (x, y)) in enumerate(trainLoader):
        # 将输入图像及标签传递给设备
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

        # 向前传递并计算训练损失
        pred = model(x)
        loss = lossFunc(pred, y)

        # 计算梯度
        loss.backward()

        # 检查是否正在更新模型参数,如果是 更新它们,并将之前累积的梯度清零
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()

        # 将损失加上迄今为止的总训练损失,同样累加正确预测的数量
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()

        # 关闭autograd并将模型置于评估模式中——这是使用PyTorch进行评估时的要求
        # switch off autograd
        with torch.no_grad():
            # 设置模型为评估模式
            model.eval()

            # 在valLoader中循环所有数据点,对它们进行预测,并计算总损失和正确验证预测的数量。
            # 遍历验证集
            for (x, y) in valLoader:
                # 把输入传递到模型
                (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

                # 预测及计算验证损失
                pred = model(x)
                totalValLoss += lossFunc(pred, y)

                # 计算正确预测数
                valCorrect += (pred.argmax(1) == y).type(
                    torch.float).sum().item()

        # 计算训练和验证平均损失
        avgTrainLoss = totalTrainLoss / trainSteps
        avgValLoss = totalValLoss / valSteps

        # 计算训练和验证精度
        trainCorrect = trainCorrect / len(trainDS)
        valCorrect = valCorrect / len(valDS)

        # 更新训练历史值
        H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
        H["train_acc"].append(trainCorrect)
        H["val_loss"].append(avgValLoss.cpu().detach().numpy())
        H["val_acc"].append(valCorrect)

        # 打印训练和验证信息
        print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
        print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
            avgTrainLoss, trainCorrect))
        print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
            avgValLoss, valCorrect))

# 打印训练的最终耗时
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# 绘制训练和损失精确度图
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.FINETUNE_PLOT)

# 序列化模型到磁盘
torch.save(model, config.FINETUNE_MODEL)

inference.py

# 使用PyTorch应用迁移学习的两种方法:特征提取、微调,这两种方法都使模型获得了80-90%的准确率

# 微调迁移学习能获得更好的结果

# USAGE
# python inference.py --model output/warmup_model.pth
# python inference.py --model output/finetune_model.pth

import argparse  # 解析命令行参数

import matplotlib.pyplot as plt  # 绘制输出图像及预测结果
import torch  # PyTorch绑定函数及方法
from torchvision import transforms  # 通过顺序的方式执行一系列数据预处理

# 导入必要的包
from pyimagesearch import config  # 全局配置文件
from pyimagesearch import create_dataloaders  # 帮助函数以根据图像目录创建DataLoader对象得到dataset/val文件夹

# 构建命令行参数及解析
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=False, default="output/warmup_model.pth",
                help="path to trained model model")
args = vars(ap.parse_args())

# 创建 数据预处理管道
# 调整图像大小并将其裁剪为IMAGE_SIZE尺寸
# 将生成的图像转换为PyTorch张量
# 执行平均值缩放
testTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 计算反均值和标准差 calculate the inverse mean and standard deviation
invMean = [-m / s for (m, s) in zip(config.MEAN, config.STD)]
invStd = [1 / s for s in config.STD]

# 定义去归一化变换 define our de-normalization transform 以展示图片到屏幕
deNormalize = transforms.Normalize(mean=invMean, std=invStd)

# 初始化测试集和DataLoader
print("[INFO] loading the dataset...")
(testDS, testLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                         transforms=testTransform, batchSize=config.PRED_BATCH_SIZE,
                                                         shuffle=True)

# 检查是否有可用GPU,如果是,定义相应地图位置
if torch.cuda.is_available():
    map_location = lambda storage, loc: storage.cuda()

# 否则,使用cpu训练模型
else:
    map_location = "cpu"

# 加载模型
print("[INFO] loading the model...")
model = torch.load(args["model"], map_location=map_location)

# 设置模型为cpu/gpu训练,设置为评估模式
model.to(config.DEVICE)
model.eval()

# 获取一批测试数据集
batch = next(iter(testLoader))
(images, labels) = (batch[0], batch[1])

# 初始化图像
fig = plt.figure("Results", figsize=(10, 10))

# switch off autograd
with torch.no_grad():
    # 把图像传递到设备
    images = images.to(config.DEVICE)

    # 执行预测
    print("[INFO] performing inference...")
    preds = model(images)

    # 遍历所有批次
    for i in range(0, config.PRED_BATCH_SIZE):
        # 初始化一个子图以绘制图像和预测结果
        # ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1) # 4行1列
        ax = plt.subplot(config.PRED_BATCH_SIZE / 2, 2, i + 1)  # 2行2列

        # 获取图像,反归一化,缩放原始像素为[0,255] 并更改通道,从第一个通道到最后一个通道排序
        # 通过“撤销”平均缩放和交换颜色通道顺序来取消图像的标准化
        image = images[i]
        image = deNormalize(image).cpu().numpy()
        image = (image * 255).astype("uint8")
        image = image.transpose((1, 2, 0))

        # 获取正确的标签
        idx = labels[i].cpu().numpy()
        gtLabel = testDS.classes[idx]

        # 获取预测的标签
        pred = preds[i].argmax().cpu().numpy()
        predLabel = testDS.classes[pred]

        # 添加真实标签及预测标签到图像上
        info = "Ground Truth: {}, Predicted: {}".format(gtLabel,
                                                        predLabel)
        plt.imshow(image)
        plt.title(info)
        plt.axis("off")

    # 展示
    plt.tight_layout()
    plt.show()

参考

猜你喜欢

转载自blog.csdn.net/qq_40985985/article/details/131207024