图像复原的金字塔自注意力网络PANet | Pyramid Attention Networks for Image Restoration

应该是投向ECCV2020的文章,Non-local与多尺度特征结合的第二篇(第一篇见https://blog.csdn.net/weixin_42096202/article/details/106222582),不过本质上是不一样。第一篇论文是在不同尺度特征上分别应用Non-Local然后进行融合,比较机械化。这篇论文是在多个尺度特征上应用一个Non-Local,即挖掘索引像素点与对应多尺度特征的响应,突破了自注意力只能捕获单一尺度的特征依赖缺陷

论文地址:https://arxiv.org/pdf/2004.13824.pdf
Github:https://github.com/SHI-Labs/Pyramid-Attention-Networks

在这里插入图片描述

Abstract:

自相似性是指在图像复原算法中广泛使用的图像先验,在不同的位置和尺度上往往会出现小的但相似的图案。但是,最近的基于深度卷积神经网络的先进图像复原方法依靠仅处理相同尺度信息的自注意神经模块无法充分利用自相似性。为了解决这个问题,我们提出了一种新颖的金字塔注意力模块用于图像复原,该模块从多尺度特征金字塔中捕获远j距离特征对应关系。受到诸如噪声或压缩伪影之类的损坏在较粗糙的图像尺度下急剧下降这一事实的启发,我们的注意力模块被设计为能够从较粗的级别的“干净”对应中借用干净的信号。一个通用的构建块,可以灵活地集成到各种神经体系结构中。通过对多种图像恢复任务的广泛实验来验证其有效性:图像去噪,去马赛克,压缩伪像减少和超分辨率。我们的PANet(金字塔形)具有简单网络骨干的注意力关注模块)可以产生具有卓越准确性和视觉质量的最新结果。

Introduction:

在这里插入图片描述
当前的自注意力机制存在以下问题:

1.如Non-Local模块等都是集中在单一尺度特征提取全局先验。因此未能捕获发生在不同尺度上的有用的特征依赖关系。

2.在自注意力模块中使用的逐像素匹配通常对低级视觉任务很嘈杂,从而降低了性能。 从直觉上讲,扩大搜索空间会增加寻找更好匹配的可能性,但对于现有的自注意模块而言并非如此。 与采用大量降维操作的高级特征图不同,图像复原网络通常会保持输入的空间大小。 因此,特征仅与局部区域高度相关,因此容易受到噪声信号的影响。 这与传统的非局部滤波方法相一致,在传统的非局部滤波方法中,逐像素匹配的效果比块匹配要差得多。

因此本文提出的自注意力机制充分利用了传统的Non-Local操作的优势,但旨在更好地符合图像复原的性质。 特别是,原始搜索空间在很大程度上从单个要素图扩展到了多尺度要素金字塔。

Methods:

1.Scale Agnostic Attention && Pyramid Attention
在这里插入图片描述
如上图所示:

1.图(a)为Non-Local注意力,在单一尺度上捕获像素的全局响应;
2.图(b)为Scale agnostic注意力,可以捕获两个尺度上的全局像素响应;
3.图(c))Pyramid注意力,捕获多个尺度上的全局响应。

在这里插入图片描述
具体实现方式:Pyramid Attention是先提取得到多个尺度特征,然后按照bottom-up的方式逐次对相邻两个尺度特征应用Scale Agnostic Attention实现。其中,Scale Agnostic Attention是以分块匹配的自注意力操作代替Non-Local中的逐像素匹配操作

扫描二维码关注公众号,回复: 11384755 查看本文章

接下来从代码的角度看一下Pyramid Attention的实现方式:
1.使用双三次插值方式构建5个尺度的特征金字塔

2.分别对每个尺度特征提取两次图像块,分别用于重建f与转换的g(用于图像块匹配),分别对应raw_w与w

3.对在不同尺度提取到的块特征w,进行拼接作为核函数权重与输入xi进行卷积匹配,并应用Softmax函数得到自相似性特征图

4.对在不同尺度提取到的块特征raw_w进行拼接作为转置卷积核权重,与自相似性特征图进行反卷积得到输入特征

class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=common.default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1-i/10 for i in range(level)]
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = common.BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = common.BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        #theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        #按照batch_size进行分离
        input_groups = torch.split(match_base,1,dim=0)
        # patch size for matching 
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        #build feature pyramid
        for i in range(len(self.scale)):    
            ref = input
            #尺度为[1,0.9,0.8,0.7,0.6]
            if self.scale[i]!=1:
                ref  = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
            #feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            
            #sampling 取图像块的方式代替逐像素
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                      strides=[self.stride,self.stride],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            #feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            #sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            #group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0)  # [L, C, k, k]
            #normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi/ max_wi
            #matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi*self.softmax_scale, dim=1)
            
            if self.average == False:
                yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
            
            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.
            y.append(yi)
      
        y = torch.cat(y, dim=0)+res*self.res_scale  # back to the mini-batch
        return y

2.PANet
在这里插入图片描述
可以看出,结构同EDSR等网络类似,区别就是在网络中间加上了个PA Block,即Pyramid Attention进行多尺度的自注意力提取全局上下文信息。

Experiments:

1.超分辨率任务上的结果优于SAN:
在这里插入图片描述

2.Ablation study:

(1)Non-Local Attention vs PA,多尺度更具优越性:
在这里插入图片描述
(2)逐像素匹配 vs 图像块匹配:
在这里插入图片描述
(3)特征金字塔层数:
在这里插入图片描述
金字塔注意力相关图的可视化。 为了可视化目的,将特征图缩放到相同大小。 颜色越亮表示参与度越高。 可以看到,注意力集中在每个尺度上的不同位置,表明该模块能够利用多尺度递归来改善恢复
在这里插入图片描述

(4)PA Block位置:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_42096202/article/details/106240801