【超分辨率】Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shwan_ma/article/details/81429292

之前我一直在做基于CNN的超分辨率研究。最近因为工作需要,需要研究基于生成对抗网络GAN的网络来做超分辨率任务。
在这段时间以来,我发现CNN和GAN两类网络的侧重点其实完全不同。CNN旨在于忠实的恢复图像的高频信息,而GAN在于生成更真实或者说更符合人眼的高分辨率图片。

本文论文名:《Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network》
文章链接:https://arxiv.org/abs/1609.04802v1
code: https://github.com/tensorlayer/SRGAN

基于GAN和基于MSE的超分辨率图像的侧重点


这里写图片描述

MSE-based解决方案由于像素化的平均值,会显得过于平滑,而GAN驱动的超分辨率解决方案则对自然图像集合的重构产生了感知上更有说服力的解,看起来更加自然

本文摘要:

在CNN的发展中,超分辨率任务在精确度和速度上都有了很大的发展。但是在较大的上采样因子下,如何去恢复更精细的纹理信息仍然等待着处理。目前现有的超分辨率算法产生的图像都具有很高的PSNR,但是这类算法产生的结果往往缺乏高频细节,而且感知上并不令人满意。在这篇文章中,作者提出用GAN来做图像超分辨率。在本文中,作者提出了一种感知损失perceptual loss,该loss function包含了对抗损失及内容损失。通过这种loss来驱动,使得生成的高分辨率图像更加自然。大量的主观评分(MOS)证明,SRGAN产生的图片更接近真实图像

本文contributon:

1) 本文提出了SRResNet CNN框架,并且以PSNR和SSIM为优化目的,构造了16个blocks的deep ResNet
2) 本文提出了基于GAN的并以perceptual loss 为目的的SRGAN网络, 在训练GAN网络时,本文加入VGG loss作为优化目的,VGG loss在像素空间更具有不变性。
3)本文采取了MOS主观意见评价方法,对SRGAN生成的图像进行打分。

网络结构:

这里写图片描述

Generator 通过deep ResNet来学习 LR到HR之间的映射,并且以PSNR和SSIM作为其生成指标。同时用判别器来判别生成的图片是否属于自然图像。

整个网络结构还是生成对抗网络的套路,目的是去优化min-max problem
这里写图片描述

整个目标函数在GAN中非常常见,也是GAN的本质,其目的是为了训练一个生成器G来fool 判别器D,再训练判别器D来判别生成器G生成的图片还是自然真实图像。

生成器loss

content loss:

该损失函数在超分辨率网络中极为常见,即通过求MSE最小来对生成器网络进行更新,然而MSE 优化生成的超分图像往往会缺少高频信息,从而使得过平滑。
这里写图片描述

VGG loss:

这个损失函数来自于李飞飞的感知loss的一篇论文,通过将两张图片投入VGG网络中,然后求解两张特征图像的mse来进行优化。由于VGG loss来自于比较深层的网络提取出来的特征,因此这个损失更能够保证感知相似度。
这里写图片描述

Adversarial loss:

这个损失函数在GAN的生成器中极为常见,一般来说如果只用对抗损失会使得网络训练起来很难收敛,而加入了之前的MSE loss和VGG loss后,能够保证网络的收敛。
这里写图片描述

为了简化对抗损失 adversarial loss,我们对上式 用 l o g D θ D ( G θ D ( I L R ) )
来替换 l o g [ 1 D θ D ( G θ D ( I L R ) ) ]

那么生成器的loss则很明确了

(70) g l o s s = g c o n t e n t l o s s + g V G G l o s s + g a d v e r s a r i a l

代码描述为:

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

判别器loss

判别器loss只有对抗损失函数
这里写图片描述

代码描述为:

    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

实验结果:

这里写图片描述

猜你喜欢

转载自blog.csdn.net/shwan_ma/article/details/81429292