《Deep Iterative Down-Up CNN for Image Denoising》论文阅读

一、论文

《Deep Iterative Down-Up CNN for Image Denoising》

由于GPU内存的有效使用及其产生大接收场的能力,在低级视觉研究中已广泛研究了使用按比例缩小和按比例放大特征图的网络。在本文中,我们提出了一种用于图像去噪的深度迭代向下卷积神经网络(DIDN),该算法反复减少和增加特征图的分辨率。网络的基本结构受到U-Net的启发,U-Net最初是为语义分段而开发的。我们针对图像去噪任务修改了按比例缩小和按比例放大的图层。 训练常规降噪网络以处理单级噪声,或者可替代地使用噪声信息作为输入,以使用单个模型解决多级噪声。 相反,由于我们网络的高效内存使用使其能够处理多个参数,因此它能够使用单个模型处理各种噪声级别,而无需输入噪声信息作为解决方法。 因此,我们的DIDN使用基准数据集展示了最先进的性能,还展示了其在NTIRE 2019真实图像去噪挑战中的优越性。

总而言之,这项工作的贡献如下:

(1)一种新颖的CNN架构,迭代地收缩和扩展具有很大接收域的特征。

(2)修改了U-Net用于图像去噪任务的缩小和放大过程。

(3)将加权平均技术应用于高斯噪声图像去噪,从而获得了一个更通用和性能增强的模型,而没有附加参数。

(4)一种有效的方法来训练单个模型,使其可以处理多级噪声(未知噪声)而无需输入噪声信息。

 (5) 最先进的高斯图像降噪性能。

二、模型结构

U-Net [24]最初是用于语义分割的。  U-Net由两条路径组成:收缩路径可减小深层特征的尺寸,而扩展路径可增大这些特征的尺寸。  U-Net的核心原理是降低特征的分辨率以增加接收域,然后通过匹配分辨率级别的级联来重用特征,以最大程度地减小由向下缩放所引起的信息丢失。  形结构的效率在图像去噪中得到了验证[14],并归因于其大的接收场,高的GPU内存效率和较低的计算成本。  DBPN [25]证明了特征图的迭代向下缩放对于学习图像超分辨率任务是有效的。 但是,此方法增加了特征图的整体大小,减小了接收场并增加了计算复杂度。利用[24]和[25],我们提出了一种称为深度迭代下降网络(DIDN)的迭代下降规模缩放网络,通过顺序重复收缩和扩展过程,提供了较大的接收范围和有效的GPU内存使用率。 图2显示了(a)DIDN的体系结构(b)DUB的结构图2:提议的DIDN的体系结构。步幅为2个子像素层,缩放比例为2。 在这里,灰色块代表特征图,特征图构建由四个不同分辨率级别组成的分层结构。  DIDN由四个部分组成:特征提取,向下块(DUB),重建和增强。

初始特征提取:当输入图像的大小为时,DIDN首先在输入图像上使用3*3卷积提取ܰ个特征,然后通过步长为2的卷积层提取个大小的特征 。

DUB:提取的特征经过多个DUB的迭代缩减缩放。 在DUB中,收缩和扩展是通过两个缩小和放大过程执行的。 跨度为2的3×3卷积层和子像素层分别用于缩小和放大。 在缩小过程中,特征图的大小在水平和垂直方向上减小一半,并且特征数量增加一倍。 在放大过程中,由于通过子像素层将输入特征的数量减少了四分之一,因此在子像素层之前通过1ൈ1卷积层增加了特征图的数量,以保持信息密度。 就像在U-Net [24]中一样,将相同分辨率级别的特征连接在一起,以增加这些特征在层次结构中的重用。 块开头和结尾的功能通过跳过连接[21]进行链接。

重构:受MemNet [9]的启发,我们在最后一个DUB之后放置了一个通用的重构块,以利用所有本地输出。 所有DUB的输出形成重建模块的输入,并且重建模块的所有输出被串联以经过增强阶段。 重建块由九个卷积层(Conv)和参数化整流线性单元(PReLU)组成[38]。 更具体地说,有四个连续的残差块,由“ Conv + PReLU + Conv + PReLU”组成,最后还有其他Conv。

增强功能:最后,通过1*1卷积,减少了重建块中输出特征图的数量,并在子像素层执行了放大操作,以生成最终的去噪图像。

三、代码

下载:https://github.com/SonghyunYu/DIDN/blob/master/gray_model.py

import torch
import torch.nn as nn
import math


class _Residual_Block(nn.Module):
    def __init__(self):
        super(_Residual_Block, self).__init__()

        #res1
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.PReLU()
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu4 = nn.PReLU()
        #res1
        #concat1

        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu6 = nn.PReLU()

        #res2
        self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu8 = nn.PReLU()
        #res2
        #concat2

        self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu10 = nn.PReLU()

        #res3
        self.conv11 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu12 = nn.PReLU()
        #res3

        self.conv13 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1, stride=1, padding=0, bias=False)
        self.up14 = nn.PixelShuffle(2)

        #concat2
        self.conv15 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False)
        #res4
        self.conv16 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu17 = nn.PReLU()
        #res4

        self.conv18 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1, padding=0, bias=False)
        self.up19 = nn.PixelShuffle(2)

        #concat1
        self.conv20 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False)
        #res5
        self.conv21 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu22 = nn.PReLU()
        self.conv23 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu24 = nn.PReLU()
        #res5

        self.conv25 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)




    def forward(self, x):
        res1 = x
        out = self.relu4(self.conv3(self.relu2(self.conv1(x))))
        out = torch.add(res1, out)
        cat1 = out

        out = self.relu6(self.conv5(out))
        res2 = out
        out = self.relu8(self.conv7(out))
        out = torch.add(res2, out)
        cat2 = out

        out = self.relu10(self.conv9(out))
        res3 = out

        out = self.relu12(self.conv11(out))
        out = torch.add(res3, out)

        out = self.up14(self.conv13(out))

        out = torch.cat([out, cat2], 1)
        out = self.conv15(out)
        res4 = out
        out = self.relu17(self.conv16(out))
        out = torch.add(res4, out)

        out = self.up19(self.conv18(out))

        out = torch.cat([out, cat1], 1)
        out = self.conv20(out)
        res5 = out
        out = self.relu24(self.conv23(self.relu22(self.conv21(out))))
        out = torch.add(res5, out)

        out = self.conv25(out)
        out = torch.add(out, res1)

        return out

class Recon_Block(nn.Module):
    def __init__(self):
        super(Recon_Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.PReLU()
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu4 = nn.PReLU()

        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu6= nn.PReLU()
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu8 = nn.PReLU()

        self.conv9 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu10 = nn.PReLU()
        self.conv11 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu12 = nn.PReLU()

        self.conv13 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu14 = nn.PReLU()
        self.conv15 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu16 = nn.PReLU()

        self.conv17 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)


    def forward(self, x):
        res1 = x
        output = self.relu4(self.conv3(self.relu2(self.conv1(x))))
        output = torch.add(output, res1)

        res2 = output
        output = self.relu8(self.conv7(self.relu6(self.conv5(output))))
        output = torch.add(output, res2)

        res3 = output
        output = self.relu12(self.conv11(self.relu10(self.conv9(output))))
        output = torch.add(output, res3)

        res4 = output
        output = self.relu16(self.conv15(self.relu14(self.conv13(output))))
        output = torch.add(output, res4)

        output = self.conv17(output)
        output = torch.add(output, res1)

        return output



class _NetG(nn.Module):
    def __init__(self):
        super(_NetG, self).__init__()

        self.conv_input = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu1 = nn.PReLU()
        self.conv_down = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu2 = nn.PReLU()

        self.recursive_A = _Residual_Block()
        self.recursive_B = _Residual_Block()
        self.recursive_C = _Residual_Block()
        self.recursive_D = _Residual_Block()
        self.recursive_E = _Residual_Block()
        self.recursive_F = _Residual_Block()

        self.recon = Recon_Block()
        #concat

        self.conv_mid = nn.Conv2d(in_channels=1536, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu3 = nn.PReLU()
        self.conv_mid2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu4 = nn.PReLU()

        self.subpixel = nn.PixelShuffle(2)
        self.conv_output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)




    def forward(self, x):
        residual = x
        out = self.relu1(self.conv_input(x))
        out = self.relu2(self.conv_down(out))

        out1 = self.recursive_A(out)
        out2 = self.recursive_B(out1)
        out3 = self.recursive_C(out2)
        out4 = self.recursive_D(out3)
        out5 = self.recursive_E(out4)
        out6 = self.recursive_F(out5)

        recon1 = self.recon(out1)
        recon2 = self.recon(out2)
        recon3 = self.recon(out3)
        recon4 = self.recon(out4)
        recon5 = self.recon(out5)
        recon6 = self.recon(out6)

        out = torch.cat([recon1, recon2, recon3, recon4, recon5, recon6], 1)

        out = self.relu3(self.conv_mid(out))
        residual2 = out
        out = self.relu4(self.conv_mid2(out))
        out = torch.add(out, residual2)

        out= self.subpixel(out)
        out = self.conv_output(out)
        out = torch.add(out, residual)

        return out

四、学习资料

IR_DIDN论文总结+结构实现

猜你喜欢

转载自blog.csdn.net/LiuJiuXiaoShiTou/article/details/106162031