A 40,000-word deep dive into image super-resolution with Real-ESRGAN: paper notes + code reading

Table of contents

1. Introduction

2. Key innovations

1. ESRGAN

2. Real-ESRGAN

3. Generator structure

1. Overall structure

2. RRDB structure

4. Discriminator structure

5. High-order degradation model

6. Loss function

1. Generator loss function

2. Discriminator loss function


1. Introduction

        Super-Resolution (SR) refers to raising the resolution of an image through hardware or software, recovering a high-resolution image from one or more low-resolution images. In plain terms, it enlarges an image while keeping the original sharpness. Deep learning is now a common approach to super-resolution, and it has a further advantage: degradations can be injected during data augmentation, so the same model can also deblur, denoise, and remove watermarks while super-resolving.

        Deep-learning super-resolution has a few milestones: SRCNN → SRGAN → ESRGAN → Real-ESRGAN. SRCNN and SRGAN are dated and rarely used now. Real-ESRGAN is an upgrade built on ESRGAN, so this post focuses on Real-ESRGAN, with ESRGAN as supporting material.

        ESRGAN paper address: https://arxiv.org/abs/1809.00219

        Real-ESRGAN paper address: https://arxiv.org/abs/2107.10833v2

        Code address: https://github.com/oaifaye/dcm-denoise-SR

2. Key innovations

1. ESRGAN

        (1) A new backbone: RRDB (Residual in Residual Dense Block). "Dense" here refers not to fully connected layers but to the dense residual connections among convolutional layers. This yields a deeper, more complex structure and a higher network capacity.

        (2) Removal of BN layers. The authors found that deep BN layers, when trained under the GAN framework, tend to produce artifacts and reduce the stability and consistency of training. Removing BN also improves the model's generalization ability and reduces computational cost and memory usage.

        (3) Residual scaling. The residual branch output is multiplied by a constant between 0 and 1 (0.2, determined experimentally) before the addition, which makes training more stable and reduces artifacts while preserving texture. (This is distinct from ESRGAN's network interpolation, which blends the weights of a PSNR-oriented model and a GAN-trained model.)

        (4) An improved discriminator using the relativistic RaGAN, which learns to judge "whether one image is more realistic than another" instead of "whether an image is real or fake"; the figure in the paper illustrates this vividly. Note that Real-ESRGAN replaces the VGG-style discriminator backbone and does not use the RaGAN discriminator at all; a sketch follows purely for illustration.
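        A minimal sketch of the RaGAN (relativistic average) discriminator loss; real_logits and fake_logits are hypothetical raw discriminator outputs, and again, Real-ESRGAN itself does not use this loss:

import torch
import torch.nn.functional as F

def ragan_d_loss(real_logits, fake_logits):
    # Relativistic average GAN: score how much more realistic a real image is
    # than the average fake, and how much less realistic a fake is than the average real.
    real_rel = real_logits - fake_logits.mean()
    fake_rel = fake_logits - real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel))
    return (loss_real + loss_fake) / 2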

2. Real-ESRGAN

        Real-ESRGAN is optimized on the basis of ESRGAN; the main changes are as follows:

        (1) A high-order degradation process for synthesizing training data: several classical degradation steps (including a sinc filter) are chained to model low-quality images that are closer to real-world ones. The authors ultimately adopted a second-order degradation process as a good balance between simplicity and effectiveness. This is the key point and is covered in detail later.

        (2) A U-Net discriminator instead of VGG. Real-ESRGAN's discriminator needs greater discriminative power for complex training outputs, and it must produce accurate gradient feedback for local textures rather than only distinguish global styles. The U-Net outputs a realism value for each pixel, providing detailed per-pixel feedback to the generator and strengthening adversarial learning of image details. The discriminator is detailed below.

        (3) Spectral normalization, introduced to stabilize the training instability caused by the complex data and the U-Net discriminator.
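        In PyTorch, spectral normalization is a one-line wrapper around a layer; a minimal example (the shapes here are arbitrary):

import torch.nn as nn
from torch.nn.utils import spectral_norm

# Re-normalize the conv weight by its largest singular value on each forward pass,
# which bounds the layer's Lipschitz constant and stabilizes GAN training.
conv = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1, bias=False))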

3. Generator structure

1. Overall structure

        Let's take batch_size=1, a 64×64 input, and 4× super-resolution as an example. The overall structure of the generator is as follows:

        As can be seen, the overall model is not complicated and is largely sequential. The input passes through 23 RRDB modules, each composed of 3 ResidualDenseBlocks, whose input and output shapes are identical. The features are then upsampled twice with nearest-neighbor interpolation, each upsample followed by a convolution that refines the interpolated details; a final convolution brings the channel count down to 3 for the output.

        The general scheme of the generator is the same as SRGAN's, but the 16 residual blocks before upsampling are replaced with 23 RRDB modules, which greatly improves feature extraction; this is why ESRGAN restores image details so well. Inside each block, the residual scaling described above is applied: the residual output is multiplied by 0.2 before being added to the identity branch, improving training stability.

Code:

# Location: basicsr/archs/rrdbnet_arch.py
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.arch_util import default_init_weights, make_layer, pixel_unshuffle
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        # the trunk: 23 RRDB blocks
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
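        A quick shape check matching the example above (batch_size=1, 64×64 input, 4× scale); this is an illustrative snippet, assuming RRDBNet as defined above:

# Hypothetical smoke test: with scale=4, a 1x3x64x64 input becomes 1x3x256x256.
model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4)
x = torch.randn(1, 3, 64, 64)
print(model(x).shape)  # torch.Size([1, 3, 256, 256])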

2. RRDB structure

        The core of Real-ESRGAN is the RRDB, which is characterized by dense residual connections: earlier feature maps are reused by concatenation (Concat) rather than addition. The structure diagram is as follows:

        It looks busy but is actually very regular: the output of each convolution + activation layer is fed into every node below it, so conv2 through conv5 each receive the concatenation of the block input and all preceding outputs.

Code:

# Location: basicsr/archs/rrdbnet_arch.py (same file as above)
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x

4. Discriminator structure

        The discriminator is a U-Net with spectral normalization; its structure is as follows:

         The discriminator has three parts:

        Downsample: three spectrally normalized convolutional layers; each doubles the channel count and halves the width and height.

        Upsample: interpolation-based upsampling plus three spectrally normalized convolutional layers; each stage halves the channels and doubles the width and height, with skip connections back to the corresponding Downsample features.

         Output layer: two spectrally normalized convolutions and one final convolutional output layer. A condensed code sketch follows.
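        For reference, here is a condensed sketch modeled on the official Real-ESRGAN discriminator (UNetDiscriminatorSN in realesrgan/archs/discriminator_arch.py); details such as the interpolation mode may differ in the fork used by this post:

# A sketch of the U-Net discriminator with spectral normalization, modeled on
# the official UNetDiscriminatorSN; verify the details against your own checkout.
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

class UNetDiscriminatorSN(nn.Module):

    def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True):
        super().__init__()
        self.skip_connection = skip_connection
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        # downsample: channels x2, width/height /2 at each layer
        self.conv1 = spectral_norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = spectral_norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = spectral_norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample: channels /2 after each interpolation
        self.conv4 = spectral_norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = spectral_norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = spectral_norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # output layers
        self.conv7 = spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)  # one-channel per-pixel realism map

    def forward(self, x):
        lrelu = lambda t: F.leaky_relu(t, negative_slope=0.2, inplace=True)
        x0 = lrelu(self.conv0(x))
        x1 = lrelu(self.conv1(x0))
        x2 = lrelu(self.conv2(x1))
        x3 = lrelu(self.conv3(x2))
        # upsample with skip connections back to the downsample features
        x4 = lrelu(self.conv4(F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)))
        if self.skip_connection:
            x4 = x4 + x2
        x5 = lrelu(self.conv5(F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)))
        if self.skip_connection:
            x5 = x5 + x1
        x6 = lrelu(self.conv6(F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)))
        if self.skip_connection:
            x6 = x6 + x0
        out = self.conv9(lrelu(self.conv8(lrelu(self.conv7(x6)))))
        return out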

5. High-order degradation model

        The high-order degradation model is Real-ESRGAN's most important innovation. Classical degradation models cannot simulate complex real-world degradations, especially unknown noise and complicated artifacts, so their synthesized low-resolution images still differ considerably from realistically degraded ones. Real-ESRGAN therefore extends the classical degradation model to a higher-order process to simulate more realistic degradations.

        In plain terms, the high-order degradation model is a pipeline assembled from classical degradation operations. The paper divides them into four categories: Blur, Resize, Noise, and JPEG Compression, as shown in the figure below:

        As the code shows, the full degradation model runs through these four processes twice, randomly selecting an algorithm for each step. The steps are as follows (typical option values are sketched after this list):

        1.1 Blur: probabilistically choose between the sinc filter and the other blur kernels (iso/aniso/generalized_iso/generalized_aniso/plateau_iso/plateau_aniso); the sinc filter probability defaults to 10%. The sinc filter simulates ring artifacts and overshoot artifacts, which look like this:

        1.2 Resize: randomly scale up or down, choosing one of the area/bilinear/bicubic interpolation modes;

        1.3 Noise: randomly choose gaussian/poisson for the noise distribution and color/gray for the noise form; color noise gives the three channels different values (default probability 60%), while gray noise gives all three channels the same value (default probability 40%);

        1.4 JPEG compression: the quality defaults to the range 30-95;

        2.1 Blur: executed with a default probability of 80%; the kernel choice is the same as 1.1;

        2.2 Resize: Same as 1.2;

        2.3 Noise: Same as 1.3;

        2.4 JPEG compression: this step is special; there are two possible orderings, [resize back + sinc filter] + JPEG compression, or JPEG compression + [resize back + sinc filter], where "resize back" resizes the image to the final LR size (the GT size divided by the scale factor).
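        For orientation, here is a sketch of the option keys the degradation code below reads from self.opt, with values mirroring the official x4plus training config (treat them as assumptions and check your own .yml):

# Assumed defaults, mirroring the official train_realesrgan_x4plus.yml.
degradation_opt = {
    # first degradation step
    'resize_prob': [0.2, 0.7, 0.1],      # probabilities of up / down / keep
    'resize_range': [0.15, 1.5],
    'gaussian_noise_prob': 0.5,
    'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3],
    'gray_noise_prob': 0.4,
    'jpeg_range': [30, 95],
    # second degradation step
    'second_blur_prob': 0.8,
    'resize_prob2': [0.3, 0.4, 0.3],
    'resize_range2': [0.3, 1.2],
    'gaussian_noise_prob2': 0.5,
    'noise_range2': [1, 25],
    'poisson_scale_range2': [0.05, 2.5],
    'gray_noise_prob2': 0.4,
    'jpeg_range2': [30, 95],
}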

       The code that randomly generates the degradation kernels is in realesrgan_dataset.py:

# Location: realesrgan/data/realesrgan_dataset.py

......
# ------------------------ Randomly generate the kernels for the first degradation step ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        # Probabilistically choose between the sinc filter and the other blur kernels; sinc filter probability defaults to 10%
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            # blur
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Randomly generate the kernels for the second degradation step ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- Randomly generate the sinc kernel for the final step ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor
......
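        To eyeball a sinc kernel yourself, a small example (assuming circular_lowpass_kernel is importable from basicsr.data.degradations, as the dataset code above implies):

import numpy as np
from basicsr.data.degradations import circular_lowpass_kernel

# A 21x21 circular low-pass (sinc) kernel with cutoff frequency pi/3;
# the kernel is normalized, so its entries sum to ~1.
k = circular_lowpass_kernel(np.pi / 3, 21, pad_to=False)
print(k.shape, k.sum())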

        The code that executes the degradation process:

# Location: realesrgan/models/realesrgan_model.py
......
    # ----------------------- The first degradation process ----------------------- #
    # 1.1 blur
    out = filter2D(self.gt_usm, self.kernel1)
    # 1.2 random resize
    updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
    if updown_type == 'up':
        scale = np.random.uniform(1, self.opt['resize_range'][1])
    elif updown_type == 'down':
        scale = np.random.uniform(self.opt['resize_range'][0], 1)
    else:
        scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(out, scale_factor=scale, mode=mode)
    # 1.3 add noise
    gray_noise_prob = self.opt['gray_noise_prob']
    if np.random.uniform() < self.opt['gaussian_noise_prob']:
        out = random_add_gaussian_noise_pt(
            out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
    else:
        out = random_add_poisson_noise_pt(
            out,
            scale_range=self.opt['poisson_scale_range'],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)
    # 1.4 JPEG compression
    jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
    out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
    out = self.jpeger(out, quality=jpeg_p)

    # ----------------------- The second degradation process ----------------------- #
    # 2.1 blur
    if np.random.uniform() < self.opt['second_blur_prob']:
        out = filter2D(out, self.kernel2)
    # 2.2 random resize
    updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
    if updown_type == 'up':
        scale = np.random.uniform(1, self.opt['resize_range2'][1])
    elif updown_type == 'down':
        scale = np.random.uniform(self.opt['resize_range2'][0], 1)
    else:
        scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(
        out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
    # 2.3 add noise
    gray_noise_prob = self.opt['gray_noise_prob2']
    if np.random.uniform() < self.opt['gaussian_noise_prob2']:
        out = random_add_gaussian_noise_pt(
            out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
    else:
        out = random_add_poisson_noise_pt(
            out,
            scale_range=self.opt['poisson_scale_range2'],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)

    # 2.4 JPEG compression and the finishing steps
    # We also need to resize the image back to the desired LR size, and we group
    # [resize back + sinc filter] together as one operation.
    # There are two options:
    #   1. [resize back + sinc filter] + JPEG compression
    #   2. JPEG compression + [resize back + sinc filter]
    # Empirically, the combination (sinc + JPEG + Resize) introduces twisted lines.
    if np.random.uniform() < 0.5:
        # resize back + the final sinc filter
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
        out = filter2D(out, self.sinc_kernel)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)
    else:
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)
        # resize back + the final sinc filter
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
        out = filter2D(out, self.sinc_kernel)
......

6. Loss function

        First, the mathematical symbols:

        x: the input (low-resolution image)

        \phi: the pretrained VGG19 network

        y: ground truth

        G: the generator

        D: the discriminator

        y^{r}: the real label, a matrix of all 1s

        y^{f}: the fake label for generator outputs, a matrix of all 0s

1. Generator loss function

        The generator loss function:

L_{G}^{total}=L_{percep}+\lambda L_{G}+\eta L_{1}

        \lambda defaults to 0.1 and \eta defaults to 1.

        L_{percep}: perceptual loss. The ground truth and the generator output are each fed through a pretrained VGG19; the feature maps of conv1_2 (b×64×256×256), conv2_2 (b×128×128×128), conv3_4 (b×256×64×64), conv4_4 (b×512×32×32), and conv5_4 (b×512×16×16) are extracted and the L1 loss between them is computed (a minimal sketch follows the formula):

L_{percep}=\left \| \phi (G(x_{i}))- \phi (y_{i})\right \|_{1}
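        A minimal sketch of such a perceptual L1 loss (illustrative only: basicsr's PerceptualLoss also supports per-layer weights, and the feature indices below are my assumptions for the conv1_2/conv2_2/conv3_4/conv4_4/conv5_4 outputs of torchvision's VGG19):

import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19

class PerceptualL1(nn.Module):
    # Indices 2, 7, 16, 25, 34 are assumed to be the conv1_2/conv2_2/conv3_4/
    # conv4_4/conv5_4 outputs (before activation) in vgg19().features.
    def __init__(self, layer_ids=(2, 7, 16, 25, 34)):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features.eval()
        for p in self.vgg.parameters():
            p.requires_grad = False
        self.layer_ids = set(layer_ids)

    def forward(self, sr, gt):
        loss, x, y = 0.0, sr, gt
        for i, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if i in self.layer_ids:
                loss = loss + F.l1_loss(x, y)
        return loss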

        L_{G}: GAN loss. The generator output is fed to the discriminator (U-Net), and the binary cross-entropy loss (BCELoss) is computed between the prediction (b×1×256×256) and the real label (all 1s):

L_{G}=-y_{i}^{r}\log D(G(x_{i})) - (1-y_{i}^{r})\log (1-D(G(x_{i})))=-\log D(G(x_{i}))

        L_{1}: the L1 loss computed directly between the ground truth and the generator output:

L_{1}=\mathrm{mean}\left \| G(x_{i})-y_{i} \right \|_{1}

        Code:

# Location: realesrgan/models/realesrgan_model.py

    l_g_total = 0
    # pixel loss (L1)
    if self.cri_pix:
        l_g_pix = self.cri_pix(self.output, l1_gt)
        l_g_total += l_g_pix
        loss_dict['l_g_pix'] = l_g_pix
    # perceptual loss
    if self.cri_perceptual:
        l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
        if l_g_percep is not None:
            l_g_total += l_g_percep
            loss_dict['l_g_percep'] = l_g_percep
        if l_g_style is not None:
            l_g_total += l_g_style
            loss_dict['l_g_style'] = l_g_style
    # gan loss
    fake_g_pred = self.net_d(self.output)
    l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
    l_g_total += l_g_gan
    loss_dict['l_g_gan'] = l_g_gan

    l_g_total.backward()
    self.optimizer_g.step()

2. Discriminator loss function

        Real-ESRGAN optimizes the discriminator in two steps:

        (1) Optimize the ability to recognize real images: construct an all-ones real label y^{r} and compute the BCELoss between D(y_{i}) and y^{r}. The formula is as follows:

L_{D}^{r}=-y_{i}^{r}\log D(y_{i}) - (1-y_{i}^{r})\log (1-D(y_{i}))=-\log D(y_{i})

        (2) Optimize the ability to recognize fakes: construct an all-zeros fake label y^{f} and compute the BCELoss between D(G(x_{i})) and y^{f}. The formula is as follows:

L_{D}^{f}=-y_{i}^{f}\log D(G(x_{i})) - (1-y_{i}^{f})\log (1-D(G(x_{i})))=-\log (1-D(G(x_{i})))

        Code:

# Location: realesrgan/models/realesrgan_model.py
    self.optimizer_d.zero_grad()
    # real
    real_d_pred = self.net_d(gan_gt)
    l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
    loss_dict['l_d_real'] = l_d_real
    loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
    l_d_real.backward()
    # fake
    fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
    l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
    loss_dict['l_d_fake'] = l_d_fake
    loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
    l_d_fake.backward()
    self.optimizer_d.step()

        That wraps up this introduction to Real-ESRGAN. There are many more implementation details, and a follow-up post is coming soon, so stay tuned!
