《Toward Convolutional Blind Denoising of Real Photographs》阅读笔记

一、论文

《Toward Convolutional Blind Denoising of Real Photographs》

摘要:尽管深卷积神经网络(CNN)在加性高斯白噪声(AWGN)的图像去噪方面取得了令人瞩目的成功,但其性能在实际嘈杂的照片上仍然受到限制。 主要原因是他们的学习模型很容易在简化的AWGN模型上过度拟合,而AWGN模型与复杂的实际噪声模型大相径庭。 为了提高深层CNN去噪器的泛化能力,我们建议使用更逼真的噪声模型和真实的噪声清洁图像对来训练卷积盲去噪网络(CBDNet)。 一方面,信号噪声和机内信号处理管道都被认为可以合成真实的噪点图像。 另一方面,还包括现实世界中嘈杂的照片及其几乎无噪音的照片,以训练我们的CBDNet。 为了进一步提供一种交互式策略以方便地校正去噪结果,将具有非对称学习的噪声估计子网嵌入到CBDNet中,以抑制噪声水平的过低估计。 在现实世界中嘈杂照片的三个数据集上的大量实验结果清楚地表明,就定量指标和视觉质量而言,CBDNet的性能优于最新技术。 该代码已在https://github.com/GuoShi28/CBDNet提供。

二、学习资料

论文笔记:Toward Convolutional Blind Denoising of Real Photographs

Toward Convolutional Blind Denoising of Real Photographs

三、模型结构

  • 噪声等级子网络由五层的卷积组成,卷积核大小为 3*3,通道数为 32,激活函数采用 Relu,没有采用池化和批归一化,输出的噪声等级图和原噪声图片大小相同。

  • 去噪子网络将噪声等级图和原噪声图片一起作为输入,采用了 U-Net 的网络结构,卷积核大小为 3*3,激活函数采用 Relu,学习噪声图片的残差。

为了利用盲降噪中的不对称灵敏度,我们提出了噪声估计中的不对称损失,以避免在噪声水平图上出现估计不足误差。 给定像素i处的估计噪声水平和地面真实度σ,当时,应对其MSE施加更多的惩罚。 因此,我们将噪声估计子网中的非对称损耗定义为:

,否则为0。 通过设置0 <α<0.5,我们可以对低估误差施加更多的惩罚,以使模型很好地推广到实际噪声。 此外,我们引入了总变化量(TV)调节器来约束的平滑度, 

其中表示沿水平(垂直)方向的梯度算子。 对于非盲消噪的输出xˆ,我们将重建损失定义为:

综上所述,我们CBDNet的总体目标是:

其中分别表示非对称损耗和TV正则器的权衡参数。

四、训练过程

  • 基于真实噪声模型合成的图片和真实的噪声图片被联合在一起对网络进行训练,来增强网络处理真实图像的泛化能力。

  • 针对一个批次的合成图片, 三个损失都被计算来训练网络。

  • 针对一个批次的真实,由于噪声等级不可知,因此只有两个损失被计算来训练网络。

五、代码

github https://github.com/IDKiro/CBDNet-pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F


class CBDNet(nn.Module):
    def __init__(self):
        super(CBDNet, self).__init__()
        self.fcn = FCN()
        self.unet = UNet()
    
    def forward(self, x):
        noise_level = self.fcn(x)
        concat_img = torch.cat([x, noise_level], dim=1)
        out = self.unet(concat_img) + x
        return noise_level, out


class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        self.inc = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.outc = nn.Sequential(
            nn.Conv2d(32, 3, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        conv1 = self.inc(x)
        conv2 = self.conv(conv1)
        conv3 = self.conv(conv2)
        conv4 = self.conv(conv3)
        conv5 = self.outc(conv4)
        return conv5


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        self.inc = nn.Sequential(
            single_conv(6, 64),
            single_conv(64, 64)
        )

        self.down1 = nn.AvgPool2d(2)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            single_conv(128, 128),
            single_conv(128, 128)
        )

        self.down2 = nn.AvgPool2d(2)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256)
        )

        self.up1 = up(256)
        self.conv3 = nn.Sequential(
            single_conv(128, 128),
            single_conv(128, 128),
            single_conv(128, 128)
        )

        self.up2 = up(128)
        self.conv4 = nn.Sequential(
            single_conv(64, 64),
            single_conv(64, 64)
        )

        self.outc = outconv(64, 3)

    def forward(self, x):
        inx = self.inc(x)

        down1 = self.down1(inx)
        conv1 = self.conv1(down1)

        down2 = self.down2(conv1)
        conv2 = self.conv2(down2)

        up1 = self.up1(conv2, conv1)
        conv3 = self.conv3(up1)

        up2 = self.up2(conv3, inx)
        conv4 = self.conv4(up2)

        out = self.outc(conv4)
        return out


class single_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))

        x = x2 + x1
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


class fixed_loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, out_image, gt_image, est_noise, gt_noise, if_asym):
        h_x = est_noise.size()[2]
        w_x = est_noise.size()[3]
        count_h = self._tensor_size(est_noise[:, :, 1:, :])
        count_w = self._tensor_size(est_noise[:, :, : ,1:])
        h_tv = torch.pow((est_noise[:, :, 1:, :] - est_noise[:, :, :h_x-1, :]), 2).sum()
        w_tv = torch.pow((est_noise[:, :, :, 1:] - est_noise[:, :, :, :w_x-1]), 2).sum()
        tvloss = h_tv / count_h + w_tv / count_w

        loss = torch.mean(torch.pow((out_image - gt_image), 2)) + \
                if_asym * 0.5 * torch.mean(torch.mul(torch.abs(0.3 - F.relu(gt_noise - est_noise)), torch.pow(est_noise - gt_noise, 2))) + \
                0.05 * tvloss
        return loss

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]

猜你喜欢

转载自blog.csdn.net/LiuJiuXiaoShiTou/article/details/106056721