注意力机制——Non-local Networks(NLNet)

Non-local Networks(NLNet):NLNet是一种非局部注意力模型,通过对整个输入空间的特征进行加权求和,以捕捉全局信息。

传统的卷积神经网络(CNN)在处理图像时,只考虑了局部区域内的像素信息,忽略了全局信息之间的相互作用。NLNets通过引入非局部块来解决这个问题,该块包括一个自注意力模块,用于学习像素之间的相互作用。

自注意力模块采用注意力机制来计算每个像素与其他像素之间的相互依赖关系,并使用这些依赖关系来加权聚合所有像素的特征表示。这种全局交互方式使得模型能够在像素之间建立远距离的依赖关系,从而提高了模型的表示能力。

NonLocalBlock模块pytorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None):
        super(NonLocalBlock, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels or in_channels // 2

        # 定义 g、theta、phi、out 四个卷积层
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.out = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)

        # 定义 softmax 层,用于将 f_ij 进行归一化
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size = x.size(0)

        # 计算 g(x)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # 计算 theta(x)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # 计算 phi(x)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # 计算 f_ij
        f = torch.matmul(theta_x, phi_x)

        # 对 f_ij 进行归一化
        f_div_C = self.softmax(f)

        # 计算 y_i
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        # 计算 z_i
        y = self.out(y)
        z = y + x

        return z

NonLocalBlock模块在网络中添加:

class NLNet(nn.Module):
    def __init__(self, num_classes=10):
        super(NLNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.nonlocal1 = NonLocalBlock(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.nonlocal2 = NonLocalBlock(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(256*4*4, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.nonlocal1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.nonlocal2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = x.view(-1, 256*4*4)
        x = self.fc(x)
        return x

猜你喜欢

转载自blog.csdn.net/weixin_50752408/article/details/129584665