深度学习之训练过程中特征图的可视化

训练过程中特征图的可视化

在网络训练的过程中,有时我们想知道网络中某些层输出的特征图到底长啥样,从而能够比较清楚的知道网络在每一层到底学到了哪些有用的特征信息,也能更好的帮助我们设计优秀的网络结构。本文详细介绍了在训练过程中,某些层次特征图的可视化操作。

1、创建模型

这里我们使用预训练好权重的 AlexNet 模型

# 引入alexnet模型及权重
from torchvision.models import alexnet, AlexNet_Weights

# 初始化模型
model = alexnet(weights=AlexNet_Weights.DEFAULT)

# 输出模型信息
print(model)

# 模型的信息如下
# 冒号左边的表示模块名称、右边的表示具体的模块结构
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

2、数据预处理

本文中我打算将玫瑰花作为网络模型的输入,并将其相应的特征图进行可视化。所以,这里先贴出本次的主角。

在这里插入图片描述

import torch
from torchvision import transforms
from PIL import Image

# 加载原图片
img = Image.open("./meigui.png").convert('RGB')
print("图片原尺寸:", img.size)  # 图片原尺寸: (952, 720)

# 定义图像处理过程
transforms = transforms.Compose([
    # 将图片缩放到224*224
    transforms.Resize(224),
    # 转换成tensor,并且将图片归一化
    transforms.ToTensor()
])

# 将图片进行处理
img = transforms(img)
print("图片处理后的尺寸:", img.size())  # 图片处理后的尺寸: torch.Size([3, 224, 224])

# 将单张图片转换成batch形式,这样才符合网络的输入形式
img = torch.unsqueeze(img, dim=0)
print("图片处理后的尺寸:", img.size())  # 图片处理后的尺寸: torch.Size([1, 3, 224, 224])

# 测试前向传播是否正常(可选)
output = model(img)
# 因为alexnet默认是1000分类,所以输出的维度是[1, 1000]
print("网络输出大小:", output.size())  # 网络输出大小:torch.Size([1, 1000]) 

3、获取特定层的输出特征

# 设置需要获取哪些层的输出特征,对应的是AlexNet网络结构中(features)层中的5个卷积层的名字
# 忘记网络结构的,可往上翻一翻
layers_name = ["0", "3", "6", "8", "10"]


# 定义一个函数,用于获取某些层次中的特征图信息
# 参数依次为:模型、输入x、需要获取输出特征的层次名称
def receive_feature_map(model, x, layers_name):
    # 存储需要输出的特征矩阵信息
    outputs = []
    for name, module in model.features.named_children():
        # 依次进行前向传播,计算
        x = module(x)
        # 需要提取输出的卷积层
        if name in layers_name:
            # 将本层的输出结果添加到outputs中
            outputs.append(x)
    return outputs


# 调用函数,得到每一层的特征图输出数据
outputs = receive_feature_map(model, img, layers_name)

4、特征图可视化

有了每一层的输出特征数据,我们就可以利用 matplotlib 来绘制每一层中的所有特征图了,并且还可以将绘制好的特征图保存到 tensorboard 日志文件中,便于后续查看。tensorboard 的详细使用可以参考我的另一篇博客:Tensorboard的详细使用


import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# 使用tensorboard记录特征图
writer = SummaryWriter("logs")

# 绘制每一层卷积层输出的特征图
for index, feature_maps in enumerate(outputs):
    # 每一层的feature_maps的大小为:[N, C, W, H],分别表示batch、通道数、宽、高
    # 由于这里只有一张图片,所以通过此方法,去掉batch,[N, C, W, H] -> [C, W, H]
    img = torch.squeeze(feature_maps)
    # 将tensor转换成numpy类型
    img = img.detach().numpy()
    # 获取特征图的通道数
    channel_num = img.shape[0]
    # 网络中每一层的输出特征,最多绘制12张特征图
    num = channel_num if channel_num < 12 else 12
    fig = plt.figure()
    # 循环绘制
    for i in range(num):
        plt.subplot(3, 4, i + 1)
        # 依次绘制其中的一个通道,img的size为:[C, W, H]
        plt.imshow(img[i, :, :])
    title = "conv{}".format(index)
    # 设置每一层特征图的标题
    plt.title(title)
    plt.show()
    # 将特征图记录到tensorboard中
    writer.add_figure(title, fig, index)

5、结果

  • matplotlib 中的效果:

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

  • tensorboard 中的效果:

在这里插入图片描述

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/L28298129/article/details/126530266