Attention UNet结构及pytorch实现

注意力机制可以说是深度学习研究领域上的一个热门领域,它在很多模型上都有着不错的表现,比如说BERT模型中的自注意力机制。本博客仅作为本人在看了一些Attention UNet相关文章后所作的笔记,希望能给各位带来一点思考,注意力机制是怎么被应用在医学图像分割的。

参考文章:

  1. 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现
  2. 医学图像分割-Attention Unet
  3. 如果不知道什么是注意力机制,可以看看这篇博客:浅谈Attention-based Model【原理篇】

Attention UNet网络结构

UNet是一个用于分割领域的架构,自2015年被提出以来,在医学图像领域取得了不错的表现,成为了不少医疗影像语义分割任务的baseline。感兴趣的可以去看一下这一篇博客:Unet神经网络为什么会在医学图像分割表现好?

UNet的网络结构并不复杂,最主要的特点便是U型结构skip-connection。而Attention UNet则是使用了标准的UNet的网络架构,并在这基础上整合进去了Attention机制。更准确来说,是将Attention机制整合进了跳远连接(skip-connection)。
整个网络架构如下, 注意力block已用红色框出:在这里插入图片描述
与标准的UNet相比,整体结构是很相似的,唯一不同的是在红框内增加了注意力门。为了公式化这个过程,我们将跳远连接的输入称为x,来自前一个block的输入称为g,那么整个模块就可以用以下公式来表示了:
在这里插入图片描述

在这个公式里面,Attention就是注意力门,upsample是一个简单上采样模块,采用最近邻插值,而ConvBlock只是由两个(convolution + batch norm + ReLU)块组成的序列。唯一需要解释的是注意力。

接下来让我们看一下整个注意力门是怎么实现的,整个结构图如下:
在这里插入图片描述

整个过程不难理解 ,需要注意一下几点:

  1. x和g都被送入到1x1卷积中,将它们变为相同数量的通道数,而不改变大小
  2. 在上采样操作后(有相同的大小),他们被累加并通过ReLU
  3. 通过另一个1x1的卷积和一个sigmoid,得到一个0到1的重要性分数,分配给特征图的每个部分
  4. 然后用这个注意力图乘以skip输入,产生这个注意力块的最终输出

pytorch实现

下面的代码定义了注意力块(简化版)和用于UNet扩展路径的“up-block”。“down-block”与原UNet一样。

class AttentionBlock(nn.Module):
    def __init__(self, in_channels_x, in_channels_g, int_channels):
        super(AttentionBlock, self).__init__()
        self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size = 1),
                                nn.BatchNorm2d(int_channels))
        self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size = 1),
                                nn.BatchNorm2d(int_channels))
        self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size = 1),
                                 nn.BatchNorm2d(1),
                                 nn.Sigmoid())
    
    def forward(self, x, g):
        # apply the Wx to the skip connection
        x1 = self.Wx(x)
        # after applying Wg to the input, upsample to the size of the skip connection
        g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode = 'bilinear', align_corners = False)
        out = self.psi(nn.ReLU()(x1 + g1))
        return out*x

class AttentionUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpBlock, self).__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
        self.attention = AttentionBlock(out_channels, in_channels, int(out_channels / 2))
        self.conv_bn1 = ConvBatchNorm(in_channels+out_channels, out_channels)
        self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)
    
    def forward(self, x, x_skip):
        # note : x_skip is the skip connection and x is the input from the previous block
        # apply the attention block to the skip connection, using x as context
        x_attention = self.attention(x_skip, x)
        # upsample x to have th same size as the attention map
        x = nn.functional.interpolate(x, x_skip.shape[2:], mode = 'bilinear', align_corners = False)
        # stack their channels to feed to both convolution blocks
        x = torch.cat((x_attention, x), dim = 1)
        x = self.conv_bn1(x)
        return self.conv_bn2(x)

整个网络架构完整版实现可以参考 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现

おすすめ

転載: blog.csdn.net/weixin_41693877/article/details/108395270