Reading notes on "MemNet: A Persistent Memory Network for Image Restoration"

一、The Paper

MemNet: A Persistent Memory Network for Image Restoration

Abstract: Recently, very deep convolutional neural networks (CNNs) have attracted considerable attention in image restoration. However, as depth grows, the long-term dependency problem is rarely addressed in these very deep models, so earlier states/layers have little influence on later ones. Motivated by the fact that human thoughts are persistent, we propose a very deep persistent memory network (MemNet) that introduces a memory block, consisting of a recursive unit and a gate unit, to explicitly mine persistent memory through an adaptive learning process. The recursive unit learns multi-level representations of the current state under different receptive fields. These representations and the outputs of previous memory blocks are concatenated and sent to the gate unit, which adaptively controls how much of the previous states should be reserved and decides how much of the current state should be stored. We apply MemNet to three image restoration tasks: image denoising, super-resolution, and JPEG deblocking. Comprehensive experiments demonstrate the necessity of MemNet and its consistent superiority over the state of the art on all three tasks. Code is available at https://github.com/tyshiwo/MemNet.

The paper's first contribution is a memory block that implements a gating mechanism to help bridge long-term dependencies. In each memory block, the gate unit adaptively learns different weights for different memories, controlling how much of the long-term memory should be reserved and deciding how much of the short-term memory should be stored.
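Written out, the gate unit is essentially a 1x1 convolution over the concatenation of both kinds of memory (my rendering of the paper's formulation):

B_m = f_gate^m([B_m^1, ..., B_m^R, B_0, B_1, ..., B_{m-1}])

where [.] is channel-wise concatenation, B_m^1, ..., B_m^R are the short-term memories from the R recursions inside block m, and B_0, ..., B_{m-1} are the long-term memories from FENet and the preceding blocks.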

The second contribution is a very deep, end-to-end persistent memory network (80 convolutional layers) for image restoration. The densely connected structure helps compensate mid/high-frequency signals and ensures maximum information flow between memory blocks. To the authors' knowledge, it is by far the deepest network for image restoration.
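The 80-layer count comes from the paper's default MemNet_M6R6 setting, i.e. M = 6 memory blocks with R = 6 recursions each (the breakdown below is my own accounting under that assumption):

1 (FENet conv) + 6 x (2 x 6 + 1) (two convs per recursion plus one 1x1 gate conv per block) + 1 (ReconNet conv) = 80 convolutional layers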

The third is that the same MemNet structure achieves state-of-the-art performance on image denoising, super-resolution, and JPEG deblocking. Thanks to its strong learning ability, MemNet can even be trained with a single model to handle different levels of corruption.

二、Network Structure

Figure 1. Existing network structures (a, b) and the proposed memory block (c). The blue circles denote a recursive unit with an unfolded structure, which generates the short-term memory. The green arrows denote the long-term memory from the previous memory blocks, which is passed directly to the gate unit.

In MemNet, a feature extraction net (FENet) first extracts features from the low-quality image. Several memory blocks are then stacked in a densely connected structure to solve the image restoration task. Finally, a reconstruction net (ReconNet) is used to learn the residual rather than the direct mapping, which eases training.
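In formulas (my shorthand for the paper's notation, with x the degraded input and M memory blocks):

B_0 = f_ext(x),  B_m = M_m(B_{m-1}) for m = 1, ..., M,  y = f_rec(B_M) + x

where f_ext and f_rec are the FENet and ReconNet convolutions and M_m is the m-th memory block.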

The way the paper is written is well worth borrowing from; in simple terms, the design is a combination of connections within blocks and between blocks. The multi-supervised outputs of Figure 3 are also worth studying.
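Concretely, in the multi-supervised MemNet each block's output goes through the shared ReconNet to form an intermediate prediction, and the final output is a weighted average with learnable weights w_m (the normalization by the weight sum follows the reference code below; the paper writes it as a plain weighted sum):

y_m = f_rec(B_m) + x,  y = sum_{m=1..M} (w_m / sum_k w_k) * y_m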

三、Code

Link: https://github.com/wutianyiRosun/MemNet

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# dtype = torch.FloatTensor        # CPU
dtype = torch.cuda.FloatTensor  # GPU; only used by the commented-out pre-allocation in forward()


class MemNet(nn.Module):
    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super(MemNet, self).__init__()
        self.feature_extractor = BNReLUConv(in_channels, channels, True)  # FENet: BN + ReLU + 3x3 conv
        self.reconstructor = BNReLUConv(channels, in_channels, True)  # ReconNet: BN + ReLU + 3x3 conv
        self.dense_memory = nn.ModuleList(
            [MemoryBlock(channels, num_resblock, i + 1) for i in range(num_memblock)]
        )
        # ModuleList can be indexed like a regular Python list, but modules it contains are
        # properly registered, and will be visible by all Module methods.

        self.weights = nn.Parameter((torch.ones(1, num_memblock) / num_memblock), requires_grad=True)
        # Learnable ensemble weights: output 1, ..., output n correspond to w1, ..., wn

    # Multi-supervised MemNet architecture
    def forward(self, x):
        residual = x
        out = self.feature_extractor(x)
        w_sum = self.weights.sum(1)
        mid_feat = []  # A list containing the output of each memory block
        ys = [out]  # A list containing the FENet output and the previous memory block outputs (long-term memory)
        for memory_block in self.dense_memory:
            out = memory_block(out, ys)  # out is the output of the GateUnit, channels=64
            mid_feat.append(out)
        # pred = Variable(torch.zeros(x.shape).type(dtype), requires_grad=False)
        # Weighted ensemble of the intermediate predictions; index self.weights directly
        # (rather than self.weights.data) so the ensemble weights stay in the autograd graph and are learned.
        pred = (self.reconstructor(mid_feat[0]) + residual) * self.weights[0][0] / w_sum
        for i in range(1, len(mid_feat)):
            pred = pred + (self.reconstructor(mid_feat[i]) + residual) * self.weights[0][i] / w_sum

        return pred

    # Base MemNet architecture
    '''
    def forward(self, x):
        residual = x   #input data 1 channel
        out = self.feature_extractor(x)
        ys = [out]  #A list contains previous memblock output and the output of FENet
        for memory_block in self.dense_memory:
            out = memory_block(out, ys)
        out = self.reconstructor(out)
        out = out + residual

        return out
    '''


class MemoryBlock(nn.Module):
    """Note: num_memblock denotes the number of MemoryBlock currently"""

    def __init__(self, channels, num_resblock, num_memblock):
        super(MemoryBlock, self).__init__()
        self.recursive_unit = nn.ModuleList(
            [ResidualBlock(channels) for i in range(num_resblock)]
        )
        # self.gate_unit = BNReLUConv((num_resblock+num_memblock) * channels, channels, True)  #kernel 3x3
        self.gate_unit = GateUnit((num_resblock + num_memblock) * channels, channels, True)  # kernel 1x1

    def forward(self, x, ys):
        """ys is a list which contains long-term memory coming from previous memory block
        xs denotes the short-term memory coming from recursive unit
        """
        xs = []
        residual = x
        for layer in self.recursive_unit:
            x = layer(x)
            xs.append(x)

        # gate_out = self.gate_unit(torch.cat([xs,ys], dim=1))
        gate_out = self.gate_unit(torch.cat(xs + ys, 1))  # xs and ys are Python lists, so xs + ys concatenates them before torch.cat
        ys.append(gate_out)
        return gate_out


class ResidualBlock(torch.nn.Module):
    """ResidualBlock
    introduced in: https://arxiv.org/abs/1512.03385
    x -> (BN - ReLU - Conv) -> (BN - ReLU - Conv) -> + x   (pre-activation residual unit)
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.relu_conv1 = BNReLUConv(channels, channels, True)
        self.relu_conv2 = BNReLUConv(channels, channels, True)

    def forward(self, x):
        residual = x
        out = self.relu_conv1(x)
        out = self.relu_conv2(out)
        out = out + residual
        return out


class BNReLUConv(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(BNReLUConv, self).__init__()
        self.add_module('bn', nn.BatchNorm2d(in_channels))
        self.add_module('relu',
                        nn.ReLU(inplace=inplace))  # inplace=True modifies the input tensor directly; False allocates a new tensor
        self.add_module('conv',
                        nn.Conv2d(in_channels, channels, 3, 1, 1))  # bias defaults to True in PyTorch (learnable bias)


class GateUnit(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(GateUnit, self).__init__()
        self.add_module('bn', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=inplace))
        self.add_module('conv', nn.Conv2d(in_channels, channels, 1, 1, 0))


import torch
from torchsummary import summary

# Use device to choose whether the network runs on the GPU or the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
netG_A2B = MemNet(1, 64, 2, 2).to(device)
summary(netG_A2B, input_size=(1, 256, 256), device=str(device))
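
Beyond the parameter summary, a minimal inference sketch could look like the following (the M6R6 configuration, the 256x256 single-channel input, and the untrained weights are my own assumptions for illustration):

# Restore a noisy grayscale patch with an (untrained) MemNet.
# MemNet(1, 64, 6, 6) matches the paper's 80-layer M6R6 setting.
model = MemNet(in_channels=1, channels=64, num_memblock=6, num_resblock=6)
model.eval()  # use running BatchNorm statistics instead of per-batch statistics

noisy = torch.rand(1, 1, 256, 256)  # one 1-channel 256x256 image in [0, 1]
with torch.no_grad():
    restored = model(noisy).clamp(0.0, 1.0)  # the network adds its residual prediction to the input
print(restored.shape)  # torch.Size([1, 1, 256, 256])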

Denoising performance

 


Reposted from blog.csdn.net/LiuJiuXiaoShiTou/article/details/107746663