图像超分辨率模型:Real-ESRGAN | 论文阅读+实战记录

前言

最近需要一个超分的模型,经过调研准备用 Real-ESRGAN。特此记录论文阅读和实战过程。

论文阅读

论文地址:Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
Github:https://github.com/xinntao/Real-ESRGAN
参考视频:https://www.bilibili.com/video/BV14541117y6

主要贡献:

  • 提出一种高阶退化过程(high-order degradation process)来模拟实际退化,并利用 sinc 滤波器给训练图片添加 Ringing artifacts(振铃伪影,周围震荡波的感觉)和 Overshoot artifacts(过冲伪影,如白边)构造训练集
  • 用 U-net 而不是 VGG 作为 GAN 的 Discriminator,提高鉴别器能力并稳定训练动态
  • Real-ESRGAN 性能更优,效果更好

效果对比:
在这里插入图片描述

数据集构建:
在这里插入图片描述
二阶退化。两阶段分别经过虚化、下采样、加噪声、JPEG压缩。其中 sinc filter 用来添加伪影。

伪影类型 实例
振铃伪影 在这里插入图片描述
过冲伪影 在这里插入图片描述

Real-ESRGAN 模型结构:

  • Generator:结构和 ESRGAN 一样,不过他做了个 Pixel Unshuffle,减少图像长宽,增加 channel size。送到一个 RRDB 的残差网络最后上采样得到输出。
    在这里插入图片描述
  • Discriminator:用的 U-Net(下采样和上采样的feature之间有连接,能学习局部纹理信息),不像最初的 GAN 判别器输出 0/1(即全局是否合格),他这个输出是和原图size一致,其中每个像素点的值衡量真实性的 feedback(即各局部是否合格)。此外,用了 spectral norm(谱归一化),能提升训练稳定性,减少伪影。
    在这里插入图片描述

两阶段模型训练:
首先,用 L1 loss 训练一个小网络(PSNR-oriented model)Real-ESRNet;然后,用其初始化 Generator 的权重,并用 L1 loss、perceptual loss、GAN loss 的结合训练最终的模型。

此外,还对 ground-truth 加了锐化再训了个 Real-ESRGAN+,能提升图像生成的锐度,但不会增加伪影。

实战记录

git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
pip install basicsr
# facexlib and gfpgan are for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop

需要用到 realesrgan 模块的 RealESRGANer 类。以4倍超分为例:

import os
import cv2
import torch
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

ckpt_path = "./checkpoints/real-esrgan"
model_path = os.path.join(ckpt_path, "RealESRGAN_x4plus.pth")
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4

upsampler = RealESRGANer(
	scale=netscale,
	model_path=self.model_path,
	dni_weight=self.dni_weight,
	model=self.model,
	tile=0,                     # Tile size, 0 for no tile during testing
	tile_pad=10,                # Tile padding
	pre_pad=0,                  # Pre padding size at each border
	half=not fp16,
	device=device)

def enhance(image, width, height):
    """
        image: PIL Image Obj
        输出: PIL Image Obj
    """
    try:
        image_cv2 = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        output, _ = upsampler.enhance(image_cv2, outscale=self.outscale)  # _ is RGB/RGBA
        image_pil = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)).resize((width, height)).convert('RGB')
        # print(output, _, image_pil)
        return image_pil
    except Exception as e:
        print("enhance Exception: ", e)
    finally:
        torch.cuda.empty_cache()

测试(超分前):
在这里插入图片描述

超分后:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/muyao987/article/details/127960309
今日推荐