注意力机制——Spatial Transformer Networks(STN)

Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形、旋转等几何变换的鲁棒性。STN 可以在端到端的训练过程中自适应地学习变换参数,无需人为设置变换方式和参数。

STN 的基本结构包括三个部分:定位网络(Localization Network)、网格生成器(Grid Generator)和采样器(Sampler)。定位网络通常由卷积层、全连接层和激活函数构成,用于学习输入数据的空间变换参数。网格生成器用于生成采样网格,采样器则根据采样网格对输入数据进行采样。整个 STN 模块可以插入到任意位置,用于提高网络的对图像变形、旋转等几何变换的鲁棒性。

在 STN 中,定位网络通常由一个多层感知器(MLP)和一些辅助层(如卷积层、全连接层和激活函数)构成。MLP 的输出用于计算变换参数(如平移、旋转和缩放等),从而生成采样网格。采样器通常由双线性插值、最近邻插值和反卷积等方法实现,用于对输入数据进行采样。

STN 的优点在于,它可以学习对输入数据进行任意复杂的空间变换,从而提高网络的对图像变形、旋转等几何变换的鲁棒性。此外,STN 可以与其他深度学习模型结合使用,从而提高整个系统的性能。例如,在图像分类任务中,可以将 STN 插入到卷积神经网络中,用于对输入图像进行空间变换,增强网络对图像变形、旋转等几何变换的鲁棒性。

STN注意力模块pytorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        # 定义本地化网络,用于估计空间变换的参数
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), # 输入通道数为 1,输出通道数为 8,卷积核大小为 7
            nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
            nn.ReLU(True), # ReLU 激活函数
            nn.Conv2d(8, 10, kernel_size=5), # 输入通道数为 8,输出通道数为 10,卷积核大小为 5
            nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
            nn.ReLU(True) # ReLU 激活函数
        )
        # 定义空间变换网络,用于预测空间变换的参数
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32), # 全连接层,输入维度为 10 * 3 * 3,输出维度为 32
            nn.ReLU(True), # ReLU 激活函数
            nn.Linear(32, 3 * 2) # 全连接层,输入维度为 32,输出维度为 3 * 2
        )
        # 初始化空间变换网络的权重和偏置
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        # 使用本地化网络对输入图像进行特征提取
        xs = self.localization(x)
        # 将特征张量展开成一维张量
        xs = xs.view(-1, 10 * 3 * 3)
        # 使用空间变换网络预测空间变换的参数
        theta = self.fc_loc(xs)
        # 将一维张量转换成二维张量,用于执行仿射变换
        theta = theta.view(-1, 2, 3)
        # 使用仿射变换对输入图像进行空间变换
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

以上代码中,STN 类继承自 PyTorch 的 nn.Module 类,是一个包含了本地化网络和空间变换网络的模块。具体来说,STN 模块包含以下组件:

  • self.localization:本地化网络,用于对输入图像进行特征提取,提取出用于估计空间变换参数的特征向量。
  • self.fc_loc:空间变换网络,用于根据本地化网络提取的特征向量预测空间变换的参数。
  • self.fc_loc[2].weight.data.zero_() 和 self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)):用于初始化空间变换网络的权重和偏置,其中权重矩阵初始化为零矩阵,偏置向量初始化为一个 torch.tensor 对象,其元素为 [1,0,0,0,1,0][1,0,0,0,1,0],表示初始的空间变换为一个单位矩阵。
  • forward 方法:模块的前向传播过程。首先使用本地化网络对输入图像进行特征提取,然后将特征张量展开成一维张量,使用空间变换网络预测空间变换的参数。接着将一维张量转换成二维张量,用于执行仿射变换,并使用仿射变换对输入图像进行空间变换,最后返回变换后的图像张量。

STN模块在模型中添加:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.stn = STN()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 使用 STN 对输入图像进行空间变换
        x = self.stn(x)
        # 经过卷积和池化层处理
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

猜你喜欢

转载自blog.csdn.net/weixin_50752408/article/details/129584264
今日推荐