Image super-resolution: Efficient and Degradation-Adaptive Network for Real-World Image Super-Resolution (DASR, OPPO)

This is OPPO's blind SR method: the idea is to explicitly train a branch network to predict the degradation, and then apply that prediction to the main SR network.

1. Problem this article addresses

Blind SR that adapts to various degradation modes.
In the paper's overview figure, y is the real HR image and the LR image x is generated from it through various degradations; a branch network then predicts the degradation representation of x, and this representation is injected into the main super-resolution network.


2. Network structure

2.1 Degradation prediction stage

Loss function: v is the ground-truth representation of the degradation type (how v is constructed is explained in Section 4 below).

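Section 5 below states that the degradation regression loss is an L1 loss, so under that reading the predictor objective can be written as:

```latex
\mathcal{L}_{\mathrm{reg}} = \lVert \hat{v} - v \rVert_{1}
```

where \hat{v} is the degradation representation predicted from the LR input x.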

  1. Degradation prediction network: 6 convolutions + 1 pooling layer -> batch × 33
  2. Condition net (mapping network): a two-layer fully connected network -> batch × 5; the five numbers are the mixing weights of the experts in the main super-resolution network


import torch.nn as nn


class Degradation_Predictor(nn.Module):
    def __init__(self, in_nc=3, nf=64, num_params=100, num_networks=5, use_bias=True):
        super(Degradation_Predictor, self).__init__()

        self.ConvNet = nn.Sequential(*[
            nn.Conv2d(in_nc, nf, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=use_bias),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=use_bias),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, nf, kernel_size=5, stride=2, padding=2, bias=use_bias),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, nf, kernel_size=5, stride=1, padding=2, bias=use_bias),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, num_params, kernel_size=5, stride=1, padding=2, bias=use_bias),
            nn.LeakyReLU(0.2, True),
        ])

        self.globalPooling = nn.AdaptiveAvgPool2d((1, 1))

        self.MappingNet = nn.Sequential(*[
            nn.Linear(num_params, 15),
            nn.Linear(15, num_networks),
        ])

    def forward(self, input):
        conv = self.ConvNet(input)
        flat = self.globalPooling(conv)
        out_params = flat.view(flat.size()[:2])
        mapped_weights = self.MappingNet(out_params)
        return out_params, mapped_weights
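A quick shape check (a minimal sketch; num_params=33 here to match the 33-dim representation used for training):

```python
import torch

net_p = Degradation_Predictor(in_nc=3, nf=64, num_params=33, num_networks=5)
lq = torch.rand(2, 3, 64, 64)                # a batch of low-quality LR inputs
out_params, mapped_weights = net_p(lq)
print(out_params.shape)                      # torch.Size([2, 33]) -> degradation representation
print(mapped_weights.shape)                  # torch.Size([2, 5])  -> per-expert weights for the SR net
```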

2.2 Main SR network


# Dynamic convolution: the inputs are a feature map and the mixing weights for 5 conv experts.
# The 5 expert kernels are fused into a single kernel via these weights, and the fused kernel
# is then convolved with the feature map.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Dynamic_conv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=1, dilation=1, groups=1, if_bias=True, K=5, init_weight=False):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.if_bias = if_bias
        self.K = K

        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
        if self.if_bias:
            self.bias = nn.Parameter(torch.Tensor(K, out_planes), requires_grad=True)
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])
            if self.if_bias:
                nn.init.constant_(self.bias[i], 0)

    def forward(self, inputs):
        x = inputs['x']
        softmax_attention = inputs['weights']
        batch_size, in_planes, height, width = x.size()
        # Fold the batch into the channel dimension so that a single grouped convolution
        # can apply a different (per-sample) fused kernel to every sample in the batch.
        x = x.contiguous().view(1, -1, height, width)
        weight = self.weight.view(self.K, -1)

        # Fuse the K expert kernels into one kernel per sample (weighted sum).
        aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size)
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
            output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups*batch_size)
        else:
            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups * batch_size)

        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        return output
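A usage sketch of the dynamic convolution (init_weight=True so the bias parameter gets initialized; the softmax here only makes the example weights sum to 1, on the assumption that the mixing weights are normalized):

```python
import torch

conv = Dynamic_conv2d(in_planes=64, out_planes=64, kernel_size=3, padding=1, K=5, init_weight=True)
x = torch.rand(4, 64, 32, 32)                 # a batch of feature maps
w = torch.softmax(torch.rand(4, 5), dim=1)    # one 5-way expert weighting per sample
y = conv({'x': x, 'weights': w})              # each sample is convolved with its own fused kernel
print(y.shape)                                # torch.Size([4, 64, 32, 32])
```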

Replacing the plain convolutions in MSRResNet with dynamic convolutions yields the dynamic SR network:



# make_layer and ResidualBlockNoBNDynamic come from the DASR repo's architecture utilities.
class MSRResNetDynamic(nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, num_models=5, upscale=4):
        super(MSRResNetDynamic, self).__init__()
        self.upscale = upscale

        self.conv_first = Dynamic_conv2d(num_in_ch, num_feat, 3, groups=1, if_bias=True, K=num_models)
        self.body = make_layer(ResidualBlockNoBNDynamic, num_block, num_feat=num_feat, num_models=num_models)

        # upsampling
        if self.upscale in [2, 3]:
            self.upconv1 = Dynamic_conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, groups=1, if_bias=True, K=num_models)
            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
        elif self.upscale == 4:
            self.upconv1 = Dynamic_conv2d(num_feat, num_feat * 4, 3, groups=1, if_bias=True, K=num_models)
            self.upconv2 = Dynamic_conv2d(num_feat, num_feat * 4, 3, groups=1, if_bias=True, K=num_models)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = Dynamic_conv2d(num_feat, num_feat, 3, groups=1, if_bias=True, K=num_models)
        self.conv_last = Dynamic_conv2d(num_feat, num_out_ch, 3, groups=1, if_bias=True, K=num_models)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)


    def forward(self, x, weights):
        out = self.lrelu(self.conv_first({'x': x, 'weights': weights}))
        out = self.body({'x': out, 'weights': weights})['x']

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1({'x': out, 'weights': weights})))
            out = self.lrelu(self.pixel_shuffle(self.upconv2({'x': out, 'weights': weights})))
        elif self.upscale in [2, 3]:
            out = self.lrelu(self.pixel_shuffle(self.upconv1({'x': out, 'weights': weights})))

        out = self.lrelu(self.conv_hr({'x': out, 'weights': weights}))
        out = self.conv_last({'x': out, 'weights': weights})
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
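Putting the two networks together, a minimal end-to-end sketch (this assumes make_layer and ResidualBlockNoBNDynamic from the DASR repo are importable):

```python
import torch

net_p = Degradation_Predictor(num_params=33, num_networks=5)
net_g = MSRResNetDynamic(num_feat=64, num_block=16, num_models=5, upscale=4)

lq = torch.rand(1, 3, 64, 64)
pred_params, weights = net_p(lq)   # [1, 33] degradation representation, [1, 5] expert weights
sr = net_g(lq, weights)            # [1, 3, 256, 256] super-resolved output
```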

3. Dataset

1. Training datasets

DIV2K
Flickr2K
OST

Download link: kaggle

Following Real-ESRGAN, download the three datasets, then rescale, crop, and generate the meta info to obtain the DF2K_multiscale_sub dataset.

2. Test datasets

DIV2K test set and RealWorld38: download address

Commonly used benchmarks such as BSDS100, Set5, Set14, Urban100, etc.: download link

4. Degradation strategy

The degradation pipeline is the same as the one introduced in the paper. It is divided into three degradation spaces of increasing severity; when processing the dataset, they are applied with the following probabilities:
degree_list: ['weak_degrade_one_stage', 'standard_degrade_one_stage', 'severe_degrade_two_stage']
degree_prob: [0.3, 0.3, 0.4]
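For illustration, sampling a degradation space per training item might look like this (a hypothetical snippet mirroring the random.choices pattern used in the pipeline code below):

```python
import random

degree_list = ['weak_degrade_one_stage', 'standard_degrade_one_stage', 'severe_degrade_two_stage']
degree_prob = [0.3, 0.3, 0.4]

# Draw one degradation space for the current training sample.
degradation_degree = random.choices(degree_list, weights=degree_prob)[0]
```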

Below we take the standard_degrade_one_stage degradation space as an example; it corresponds to S2 in the paper.


The blur kernel is computed in DASRDataset. For standard_degrade_one_stage, it returns:

{'gt': img_gt, 'kernel1': kernel_info, 'gt_path': gt_path}

DASRDataset only computes the blur kernel; actually applying the blur, as well as the resize, noise, and JPEG compression steps, happens in the feed_data function of the DASRModel class:


    elif self.degradation_degree == 'standard_degrade_one_stage':
        # Step 1: blur. Blur the image according to the sampled kernel parameters and record
        # degradation_params[0:4], corresponding to v1-v4 in the paper.
        self.degradation_params = torch.zeros(self.opt_train['batch_size_per_gpu'],
                                                self.num_degradation_params)  # [B, 33]

        self.kernel1 = data['kernel1']['kernel'].to(self.device)

        kernel_size_range1 = [self.opt_train['blur_kernel_size_minimum_standard1'],
                                self.opt_train['blur_kernel_size_standard1']]
        rotation_range = [-math.pi, math.pi]
        # Normalize the blur parameters into [0, 1] and write them into slots v1-v4.
        sigma_range1 = self.opt_train['blur_sigma_standard1']
        self.degradation_params[:, self.road_map[0]:self.road_map[0] + 1] = \
            (data['kernel1']['kernel_size'].unsqueeze(1) - kernel_size_range1[0]) / (kernel_size_range1[1] - kernel_size_range1[0])
        self.degradation_params[:, self.road_map[0] + 1:self.road_map[0] + 2] = \
            (data['kernel1']['sigma_x'].unsqueeze(1) - sigma_range1[0]) / (sigma_range1[1] - sigma_range1[0])
        self.degradation_params[:, self.road_map[0] + 2:self.road_map[0] + 3] = \
            (data['kernel1']['sigma_y'].unsqueeze(1) - sigma_range1[0]) / (sigma_range1[1] - sigma_range1[0])
        self.degradation_params[:, self.road_map[0] + 3:self.road_map[0] + 4] = \
            (data['kernel1']['rotation'].unsqueeze(1) - rotation_range[0]) / (rotation_range[1] - rotation_range[0])

        ori_h, ori_w = self.gt.size()[2:4]

        # blur
        out = filter2D(self.gt, self.kernel1)
        # Step 2: resize. Parameters: the scale factor and the resampling mode.
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob_standard1'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range_standard1'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range_standard1'][0], 1)
        else:
            scale = 1
        mode = random.choice(self.resize_mode_list)
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        normalized_scale = (scale - self.opt['resize_range_standard1'][0]) / (
                    self.opt['resize_range_standard1'][1] - self.opt['resize_range_standard1'][0])
        onehot_mode = torch.zeros(len(self.resize_mode_list))
        for index, mode_current in enumerate(self.resize_mode_list):
            if mode_current == mode:
                onehot_mode[index] = 1
        self.degradation_params[:, self.road_map[1]:self.road_map[1] + 1] = torch.tensor(
            normalized_scale).expand(self.gt.size(0), 1)     # scale
        self.degradation_params[:, self.road_map[1] + 1:self.road_map[1] + 4] = onehot_mode.expand(
            self.gt.size(0), len(self.resize_mode_list))     # resize mode

        # Step 3: add noise (Gaussian or Poisson).
        # noise # noise_range: [1, 30] poisson_scale_range: [0.05, 3]
        gray_noise_prob = self.opt['gray_noise_prob_standard1']
        if np.random.uniform() < self.opt['gaussian_noise_prob_standard1']:
            sigma, gray_noise, out, self.noise_g_first = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range_standard1'], clip=True, rounds=False,
                gray_prob=gray_noise_prob)

            normalized_sigma = (sigma - self.opt['noise_range_standard1'][0]) / (
                        self.opt['noise_range_standard1'][1] - self.opt['noise_range_standard1'][0])
            self.degradation_params[:, self.road_map[2]:self.road_map[2] + 1] = normalized_sigma.unsqueeze(1)
            self.degradation_params[:, self.road_map[2] + 1:self.road_map[2] + 2] = gray_noise.unsqueeze(1)
            self.degradation_params[:, self.road_map[2] + 2:self.road_map[2] + 4] = torch.tensor([1, 0]).expand(
                self.gt.size(0), 2)
            self.noise_p_first = only_generate_poisson_noise_pt(out, scale_range=self.opt[
                'poisson_scale_range_standard1'], gray_prob=gray_noise_prob)
        else:
            scale, gray_noise, out, self.noise_p_first = random_add_poisson_noise_pt(
                out, scale_range=self.opt['poisson_scale_range_standard1'], gray_prob=gray_noise_prob,
                clip=True, rounds=False)
            normalized_scale = (scale - self.opt['poisson_scale_range_standard1'][0]) / (
                        self.opt['poisson_scale_range_standard1'][1] -
                        self.opt['poisson_scale_range_standard1'][0])
            self.degradation_params[:, self.road_map[2]:self.road_map[2] + 1] = normalized_scale.unsqueeze(1)
            self.degradation_params[:, self.road_map[2] + 1:self.road_map[2] + 2] = gray_noise.unsqueeze(1)
            self.degradation_params[:, self.road_map[2] + 2:self.road_map[2] + 4] = torch.tensor([0, 1]).expand(
                self.gt.size(0), 2)
            self.noise_g_first = only_generate_gaussian_noise_pt(out,
                                                                    sigma_range=self.opt['noise_range_standard1'],
                                                                    gray_prob=gray_noise_prob)
        # Step 4: JPEG compression. Its only parameter is the quality; the other 3 slots hold
        # the resize-back mode (one-hot).
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(
            *self.opt['jpeg_range_standard1'])  # tensor([61.6463, 94.2723, 37.1205, 34.9564], device='cuda:0')]
        normalized_jpeg_p = (jpeg_p - self.opt['jpeg_range_standard1'][0]) / (
                    self.opt['jpeg_range_standard1'][1] - self.opt['jpeg_range_standard1'][0])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)
        self.degradation_params[:, self.road_map[3]:self.road_map[3] + 1] = normalized_jpeg_p.unsqueeze(1)

        # resize back
        mode = random.choice(self.resize_mode_list)
        onehot_mode = torch.zeros(len(self.resize_mode_list))
        for index, mode_current in enumerate(self.resize_mode_list):
            if mode_current == mode:
                onehot_mode[index] = 1
        out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
        self.degradation_params[:, self.road_map[3] + 4:] = onehot_mode.expand(self.gt.size(0),
                                                                                len(self.resize_mode_list))

        self.degradation_params = self.degradation_params.to(self.device)

        # clamp and round
        self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        # random crop
        gt_size = self.opt['gt_size']
        self.gt, self.lq, self.top, self.left = paired_random_crop_return_indexes(self.gt, self.lq, gt_size,
                                                                                    self.opt['scale'])
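From the indexing above, the slots written by standard_degrade_one_stage can be summarized as follows (a hedged reading; the remaining slots of the 33-dim vector serve the other degradation spaces):

  - road_map[0] + 0..3: blur -> kernel_size, sigma_x, sigma_y, rotation (v1-v4, each min-max normalized)
  - road_map[1] + 0..3: resize -> normalized scale + one-hot resize mode
  - road_map[2] + 0..3: noise -> normalized strength, gray-noise flag, one-hot {gaussian, poisson}
  - road_map[3] + 0: JPEG -> normalized quality
  - road_map[3] + 4 onward: resize-back mode (one-hot, filling the remaining slots)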

degradation_params is thus a 33-dim vector per sample; it also serves as the ground truth for the regression loss of the degradation prediction network.

5. Loss function

1. Pixel reconstruction loss, an L1 loss:

import torch.nn.functional as F

def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction='none')
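Note the reduction='none': the helper returns a per-element loss map (in BasicSR this function is wrapped by a @weighted_loss decorator that applies element-wise weights and the final reduction). A minimal usage sketch:

```python
import torch
import torch.nn.functional as F

pred, target = torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8)
loss_map = F.l1_loss(pred, target, reduction='none')  # per-element loss, same shape as the inputs
loss = loss_map.mean()                                # reduce to a scalar before calling backward()
```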

2. The degradation regression loss is also an L1 loss.

3. The perceptual loss is configured as follows:

perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1
    style_weight: 0
    range_norm: false
    criterion: l1

The implementation works as follows: specify certain VGG layers and their corresponding weights; then, given the prediction and the GT, compute the perceptual and style losses between their layer features.

class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool):  If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()  # L2 criterion (torch has no nn.L2loss)
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
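A usage sketch matching the config above (assumes VGGFeatureExtractor from BasicSR is importable and the VGG19 weights are available):

```python
import torch

percep = PerceptualLoss(
    layer_weights={'conv1_2': 0.1, 'conv2_2': 0.1, 'conv3_4': 1, 'conv4_4': 1, 'conv5_4': 1},
    vgg_type='vgg19', use_input_norm=True, perceptual_weight=1.0, style_weight=0.)

sr, gt = torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128)
l_percep, l_style = percep(sr, gt)   # l_style is None because style_weight == 0
```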

4. GAN loss function

```yaml
gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1
```

For this article (gan_type: vanilla), it is effectively a binary classification loss, nn.BCEWithLogitsLoss().

    class GANLoss(nn.Module):
        """Define GAN loss.

        Args:
            gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
            real_label_val (float): The value for real label. Default: 1.0.
            fake_label_val (float): The value for fake label. Default: 0.0.
            loss_weight (float): Loss weight. Default: 1.0.
                Note that loss_weight is only for generators; and it is always 1.0
                for discriminators.
        """

        def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
            super(GANLoss, self).__init__()
            self.gan_type = gan_type
            self.loss_weight = loss_weight
            self.real_label_val = real_label_val
            self.fake_label_val = fake_label_val

            if self.gan_type == 'vanilla':
                self.loss = nn.BCEWithLogitsLoss()
            elif self.gan_type == 'lsgan':
                self.loss = nn.MSELoss()
            elif self.gan_type == 'wgan':
                self.loss = self._wgan_loss
            elif self.gan_type == 'wgan_softplus':
                self.loss = self._wgan_softplus_loss
            elif self.gan_type == 'hinge':
                self.loss = nn.ReLU()
            else:
                raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

        def _wgan_loss(self, input, target):
            """wgan loss.

            Args:
                input (Tensor): Input tensor.
                target (bool): Target label.

            Returns:
                Tensor: wgan loss.
            """
            return -input.mean() if target else input.mean()

        def _wgan_softplus_loss(self, input, target):
            """wgan loss with soft plus. softplus is a smooth approximation to the
            ReLU function.

            In StyleGAN2, it is called:
                Logistic loss for discriminator;
                Non-saturating loss for generator.

            Args:
                input (Tensor): Input tensor.
                target (bool): Target label.

            Returns:
                Tensor: wgan loss.
            """
            return F.softplus(-input).mean() if target else F.softplus(input).mean()

        def get_target_label(self, input, target_is_real):
            """Get target label.

            Args:
                input (Tensor): Input tensor.
                target_is_real (bool): Whether the target is real or fake.

            Returns:
                (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                    return Tensor.
            """

            if self.gan_type in ['wgan', 'wgan_softplus']:
                return target_is_real
            target_val = (self.real_label_val if target_is_real else self.fake_label_val)
            return input.new_ones(input.size()) * target_val  # a tensor filled with target_val (all ones for real targets)

        def forward(self, input, target_is_real, is_disc=False):
            """
            Args:
                input (Tensor): The input for the loss module, i.e., the network
                    prediction.
                target_is_real (bool): Whether the targe is real or fake.
                is_disc (bool): Whether the loss for discriminators or not.
                    Default: False.

            Returns:
                Tensor: GAN loss value.
            """
            target_label = self.get_target_label(input, target_is_real) #
            if self.gan_type == 'hinge':
                if is_disc:  # for discriminators in hinge-gan
                    input = -input if target_is_real else input
                    loss = self.loss(1 + input).mean()
                else:  # for generators in hinge-gan
                    loss = -input.mean()
            else:  # other gan types
                loss = self.loss(input, target_label)

            # loss_weight is always 1.0 for discriminators
            return loss if is_disc else loss * self.loss_weight
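A small sketch of how the loss is called from the training step (mirroring the is_disc usage in optimize_parameters below):

```python
import torch

cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)   # BCEWithLogitsLoss under the hood
logits = torch.randn(2, 1, 64, 64)                       # raw discriminator logits

l_g_gan = cri_gan(logits, True, is_disc=False)    # generator: make fakes look real, scaled by loss_weight
l_d_real = cri_gan(logits, True, is_disc=True)    # discriminator on real images (weight fixed at 1.0)
l_d_fake = cri_gan(logits, False, is_disc=True)   # discriminator on generated images
```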

6. Discriminator

The discriminator exists only during training; it is not used at test time. It is a standard U-Net:

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)"""

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm

        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))

        # extra
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))

        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out

The U-Net comprises 10 convolutional layers (with spectral norm and leaky ReLU). For an input of shape [2, 3, 512, 512], the features are downsampled three times and upsampled back, so the final output is a single-channel map of the same spatial size, [2, 1, 512, 512].

In short: the discriminator takes a 3-channel image and outputs a single-channel, equal-size map of per-pixel real/fake logits.
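A quick shape check (minimal sketch):

```python
import torch

net_d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
img = torch.rand(2, 3, 512, 512)
print(net_d(img).shape)   # torch.Size([2, 1, 512, 512]) -> per-pixel real/fake logits
```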

7. Training steps

The generator is partially initialized from a pre-trained MSRResNet model; trained from scratch it may fail to converge. Alternatively, you can first train a set of pre-trained weights with pixel loss only.

Explanatory comments are included in the code below.

    def optimize_parameters(self, current_iter):
        # One optimization iteration: update the generator once, then the discriminator once.
        # optimize net_g
        # 1. First optimize the generator net_g; the discriminator net_d weights stay frozen.
        for p in self.net_d.parameters():
            p.requires_grad = False

        # 2. Zero the generator gradients.
        self.optimizer_g.zero_grad()
        # 3. Forward pass of the generator; the input is a low-quality, low-resolution image.
        # predicted_params and weights are the 33-dim degradation parameters and the dynamic-conv
        # expert weights for net_g; the degradation prediction is fused into the SR network,
        # which generates the super-resolved output.
        predicted_params, weights = self.net_p(self.lq)
        self.output = self.net_g(self.lq.contiguous(), weights)
        # 4. Compute the generator losses:
        #    pixel reconstruction loss      self.cri_pix(self.output, self.gt)
        #    degradation regression loss    self.cri_regress(predicted_params, self.degradation_params)
        #    content/style perceptual loss  self.cri_perceptual(self.output, self.gt)
        #    GAN loss (fool the D)          self.cri_gan(fake_g_pred, True, is_disc=False)
        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_pix
                loss_dict['l_pix'] = l_pix
            if self.cri_regress:
                l_regression = self.cri_regress(predicted_params, self.degradation_params)
                l_g_total += l_regression
                loss_dict['l_regression'] = l_regression
            # perceptual loss
            if self.cri_perceptual:
                l_percep, l_style = self.cri_perceptual(self.output, self.gt)
                if l_percep is not None:
                    l_g_total += l_percep
                    loss_dict['l_percep'] = l_percep
                if l_style is not None:
                    l_g_total += l_style
                    loss_dict['l_style'] = l_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan
            # 5. Backpropagate and step the generator optimizer.
            l_g_total.backward()
            self.optimizer_g.step()

        
        # optimize net_d
        # 6. Optimize the discriminator: set requires_grad back to True so it is trainable.
        for p in self.net_d.parameters():
            p.requires_grad = True
        # 7. Zero the discriminator gradients.
        self.optimizer_d.zero_grad()

        
        # real
        # 8. Discriminator loss on the GT image: push the real prediction toward 1.
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        # 9. Discriminator loss on the generated output (detached): push the fake prediction toward 0.
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())

        # 10. Backpropagate and step the discriminator optimizer.
        l_d_fake.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

8. Outlook

The degradation model in this paper consists of common, hand-defined degradation types, but is it consistent with real-world degradations?
Not necessarily. In the future it could be combined with more realistic degradation models for better results: for example, modeling camera noise via Rethinking Noise Synthesis and Modeling in Raw Denoising, together with other blur and sampling models, depending on your own application. If you can model your own image degradation, it can likewise be combined with the method of this paper.

Origin: blog.csdn.net/tywwwww/article/details/128036503#comments_26429931