实验——基于pytorch的blind restoration联合网络训练

版权声明: https://blog.csdn.net/gwplovekimi/article/details/85688031

加噪和超分的数据处理可以参考本人 GitHub 代码:https://github.com/gwpscut/degradation-model-for-image-restoration

去噪和超分网络都可以参考本人之前的博文哈

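子网络(subnetwork)为 DnCNN,主网络为 SRResNet。子网络输出 noise level map(噪声水平图),训练时以 MR 数据(dataroot_MR)作为其监督目标;该输出与 LR 输入在通道维拼接后送入 SRResNet,因此 network_G 的 in_nc 为 6(3 + 3)。注意,博文《基于pytorch的超分和去噪网络联合fine tuning》中子网络的输出为 clean image(干净图像),与本文不同。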

setting

{
  "name": "finetune_all_subnetc16s06_basic_resnet_DIV2K",
  "tb_logger_dir": "sr_c16s06",
  "use_tb_logger": true,
  "model": "sr_sub",
  "scale": 4,
  "crop_scale": 0,
  "gpu_ids": [
    3,
    5
  ],
  "datasets": {
    "train": {
      "name": "DIV2K",
      "mode": "LRMRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub",
      "dataroot_MR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_residualALL",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noiseALL",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 24,
      "HR_size": 128,
      "use_flip": true,
      "use_rot": true,
      "phase": "train",
      "scale": 4,
      "data_type": "img"
    },
    "val": {
      "name": "val_set5_x4_c03s08_mod4",
      "mode": "LRMRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_MR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_residualALL",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noiseALL",
      "phase": "val",
      "scale": 4,
      "data_type": "img"
    }
  },
  "path": {
    "root": "/home/guanwp/jingwen/sr_c16s06",
    "pretrain_model_sub": "/home/guanwp/jingwen/sr/experiments/LR_x4_subnet_residual_DIV2K_guan/models/51000_G.pth",
    "experiments_root": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K",
    "models": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K/models",
    "log": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K",
    "val_images": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K/val_images"
  },
  "network_G": {
    "which_model_G": "sr_resnet",
    "norm_type": null,
    "mode": "CNA",
    "nf": 64,
    "nb": 16,
    "in_nc": 6,
    "out_nc": 3,
    "group": 1,
    "scale": 4
  },
  "network_sub": {
    "which_model_sub": "noise_subnet",
    "norm_type": "batch",
    "mode": "CNA",
    "nf": 64,
    "in_nc": 3,
    "out_nc": 3,
    "group": 1
  },
  "train": {
    "lr_G": 0.0001,
    "lr_scheme": "MultiStepLR",
    "lr_steps": [
      500000
    ],
    "lr_gamma": 0.1,
    "pixel_criterion_basic": "l2",
    "pixel_criterion_noise": "l2",
    "pixel_weight_basic": 1.0,
    "pixel_weight_noise": 1.0,
    "val_freq": 2000.0,
    "manual_seed": 0,
    "niter": 1000000.0
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 2000.0
  },
  "timestamp": "190129-133631",
  "is_train": true,
  "adabn": null
}

model

import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel


class SRModel(BaseModel):
    """Joint training model: a degradation subnet feeding an SR generator.

    ``self.subnet`` maps the degraded input ``var_L`` to an estimate
    (``fake_noise``) that is supervised against the ``MR`` target. The input
    and that estimate are concatenated along dim 1 and passed to
    ``self.netG``, whose output ``fake_H`` is supervised against ``HR``.
    The two weighted pixel losses are summed and minimized by a single Adam
    optimizer over the parameters selected by ``opt['finetune_type']``.
    """

    def __init__(self, opt):
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        # Optional key: a missing entry selects the default parameter set
        # (every parameter that currently requires grad).
        finetune_type = opt.get('finetune_type')

        # define networks and load pretrained weights (if configured)
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet = networks.define_sub(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()
            self.subnet.train()

            # losses: one on the subnet output, one on the SR output
            self.cri_pix_noise = self._build_pixel_criterion(train_opt['pixel_criterion_noise'])
            self.l_pix_noise_w = train_opt['pixel_weight_noise']
            self.cri_pix_basic = self._build_pixel_criterion(train_opt['pixel_criterion_basic'])
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizer; weight decay is optional (missing / None / 0 -> no decay)
            wd_G = train_opt.get('weight_decay_G') or 0
            self.optim_params = self.__define_grad_params(finetune_type)
            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def _build_pixel_criterion(self, loss_type):
        """Return an L1 (``'l1'``) or MSE (``'l2'``) loss module on ``self.device``.

        Raises:
            NotImplementedError: for any other ``loss_type``.
        """
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

    def feed_data(self, data, need_HR=True):
        """Move one batch onto ``self.device``.

        ``data['LR']`` (degraded input) is always loaded. The supervision
        targets ``data['HR']`` (SR target) and ``data['MR']`` (subnet target)
        are only read when ``need_HR`` is True, so inference batches without
        ground truth can be fed.
        """
        self.var_L = data['LR'].to(self.device)
        if need_HR:
            self.real_H = data['HR'].to(self.device)
            self.mid_L = data['MR'].to(self.device)

    @staticmethod
    def _freeze(net):
        """Set ``requires_grad = False`` on every parameter of ``net``."""
        for v in net.parameters():
            v.requires_grad = False

    @staticmethod
    def _select_by_name(net, keyword):
        """Freeze ``net``, then unfreeze and return the params whose name contains ``keyword``."""
        selected = []
        for k, v in net.named_parameters():
            v.requires_grad = keyword in k
            if v.requires_grad:
                selected.append(v)
                print('we only optimize params: {}'.format(k))
        return selected

    def __define_grad_params(self, finetune_type=None):
        """Choose which parameters the optimizer updates and set ``requires_grad``.

        * ``'sft'``: only netG params whose name contains ``'Gate'``; the
          subnet is frozen (previously it stayed trainable but was never
          optimized, so its gradients were computed and accumulated for nothing).
        * ``'sub_sft'``: netG ``'Gate'`` params plus subnet ``'degration'`` params.
        * ``'basic'`` / ``'sft_basic'``: all netG params; the subnet is frozen.
        * anything else: every netG and subnet param that already requires grad.

        Returns:
            list of parameters to hand to the optimizer.
        """
        optim_params = []

        if finetune_type in ('sft', 'sub_sft'):
            optim_params += self._select_by_name(self.netG, 'Gate')
            if finetune_type == 'sub_sft':
                optim_params += self._select_by_name(self.subnet, 'degration')
            else:
                self._freeze(self.subnet)
        elif finetune_type in ('basic', 'sft_basic'):
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))
            self._freeze(self.subnet)
        else:
            for net in (self.netG, self.subnet):
                for k, v in net.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                        print('params [{:s}] will optimize.'.format(k))
                    else:
                        print('WARNING: params [{:s}] will not optimize.'.format(k))
        return optim_params

    def optimize_parameters(self, step):
        """Run one joint forward/backward/update step on the current batch.

        ``step`` is accepted for interface compatibility and is unused.
        """
        self.optimizer_G.zero_grad()

        self.fake_noise = self.subnet(self.var_L)
        l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)
        # The generator sees the degraded input and the subnet estimate stacked on dim 1.
        self.fake_H = self.netG(torch.cat((self.var_L, self.fake_noise), 1))
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)
        l_pix = l_pix_noise + l_pix_basic
        l_pix.backward()

        self.optimizer_G.step()

        self.log_dict['l_pix'] = l_pix.item()
        # Component losses make it visible which branch dominates the total.
        self.log_dict['l_pix_noise'] = l_pix_noise.item()
        self.log_dict['l_pix_basic'] = l_pix_basic.item()

    def test(self):
        """Run subnet + netG on ``self.var_L`` in eval mode without building a graph.

        Sets ``self.fake_noise`` and ``self.fake_H``. Both networks are put
        back into train mode afterwards only when the model is training.
        """
        self.netG.eval()
        self.subnet.eval()
        with torch.no_grad():
            self.fake_noise = self.subnet(self.var_L)
            self.fake_H = self.netG(torch.cat((self.var_L, self.fake_noise), 1))
        if self.is_train:
            self.netG.train()
            self.subnet.train()

    def get_current_log(self):
        """Return the loss values recorded by the last ``optimize_parameters`` call."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of the current batch as float CPU tensors.

        Keys: ``'LR'`` input, ``'MR'`` (the subnet *output*, not the MR
        target), ``'SR'`` generator output and, if ``need_HR``, ``'HR'``.
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Print parameter counts of netG and subnet; in training also dump both to network.txt."""
        s_G, n_G = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n_G))
        s_sub, n_sub = self.get_network_description(self.subnet)
        print('Number of parameters in subnet: {:,d}'.format(n_sub))
        if self.is_train:
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write('-------------- Generator --------------\n' + s_G + '\n')
                f.write('\n\n\n-------------- subnet --------------\n' + s_sub + '\n')

    def load(self):
        """Load pretrained weights for netG / subnet when their paths are configured."""
        path_opt = self.opt['path']
        load_path_G = path_opt.get('pretrain_model_G')
        load_path_sub = path_opt.get('pretrain_model_sub')
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub is not None:
            print('loading model for subnet [{:s}] ...'.format(load_path_sub))
            self.load_network(load_path_sub, self.subnet)

    def save(self, iter_label):
        """Save netG (label ``'G'``) and subnet (label ``'sub'``) for ``iter_label``."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet, 'sub', iter_label)

network

import functools
import torch
import torch.nn as nn
from torch.nn import init

import models.modules.architecture as arch
import models.modules.sft_arch as sft_arch

####################
# initialize
####################


def weights_init_normal(m, std=0.02):
    """Initialize one module from a normal distribution (used via ``net.apply``).

    Conv*/Linear* weights ~ N(0, std), BatchNorm2d weights ~ N(1, std);
    biases are zeroed. Modules without a weight tensor are skipped. This
    covers norm layers built with ``affine=False`` and custom modules whose
    class name merely contains ``'Conv'``; both previously raised
    AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.normal_(weight.data, 0.0, std)
        if bias is not None:
            bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(weight.data, 1.0, std)  # BN also uses norm
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    """Kaiming-normal (fan_in) initialization for one module, scaled by ``scale``.

    Conv*/Linear* weights are Kaiming-initialized and multiplied by
    ``scale``; BatchNorm2d / InstanceNorm2d weights are set to 1. Biases
    are zeroed. Modules without a weight tensor are skipped. This covers
    ``nn.InstanceNorm2d``, which defaults to ``affine=False`` (weight is
    None) and previously crashed here.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.kaiming_normal_(weight.data, a=0, mode='fan_in')
        weight.data *= scale
        if bias is not None:
            bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        init.constant_(weight.data, 1.0)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_orthogonal(m):
    """Orthogonal initialization (gain 1) for Conv*/Linear* modules.

    BatchNorm2d weights are set to 1. Biases are zeroed. Modules without
    a weight tensor (e.g. ``BatchNorm2d(affine=False)``) are skipped instead
    of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.orthogonal_(weight.data, gain=1)
        if bias is not None:
            bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(weight.data, 1.0)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Apply one of the per-module initializers above to every submodule of ``net``.

    Args:
        net: module whose submodules are visited with ``net.apply``.
        init_type: ``'normal'``, ``'kaiming'`` or ``'orthogonal'``.
        scale: weight multiplier, only used by ``'kaiming'``.
        std: standard deviation, only used by ``'normal'``.

    Raises:
        NotImplementedError: for an unknown ``init_type`` (after the log line).
    """
    print('initialization method [{:s}]'.format(init_type))
    if init_type == 'normal':
        initializer = functools.partial(weights_init_normal, std=std)
    elif init_type == 'kaiming':
        initializer = functools.partial(weights_init_kaiming, scale=scale)
    elif init_type == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
    net.apply(initializer)


####################
# define network
####################

# Generator
def define_G(opt):
    """Build the generator named by ``opt['network_G']['which_model_G']``.

    Constructor arguments are read from ``opt['network_G']``. When
    ``opt['init_type']`` is not None the network is re-initialized with
    ``init_weights(..., scale=0.1)``. If ``opt['gpu_ids']`` is non-empty,
    CUDA availability is asserted and the network is wrapped in
    ``nn.DataParallel``.

    Raises:
        NotImplementedError: for an unrecognized model name.
    """
    gpu_ids = opt['gpu_ids']
    cfg = opt['network_G']
    model_name = cfg['which_model_G']

    # Builders are lambdas so only the selected model's config keys are read.
    builders = {
        'sr_resnet': lambda: arch.SRResNet(  # SRResNet
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], norm_type=cfg['norm_type'], act_type='relu',
            mode=cfg['mode'], upsample_mode='pixelshuffle'),
        'modulate_sr_resnet': lambda: arch.ModulateSRResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], norm_type=cfg['norm_type'], mode=cfg['mode'],
            upsample_mode='pixelshuffle', ada_ksize=cfg['ada_ksize'],
            gate_conv_bias=cfg['gate_conv_bias']),
        'arcnn': lambda: arch.ARCNN(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            norm_type=cfg['norm_type'], mode=cfg['mode'], ada_ksize=cfg['ada_ksize']),
        'srcnn': lambda: arch.SRCNN(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            norm_type=cfg['norm_type'], mode=cfg['mode'], ada_ksize=cfg['ada_ksize']),
        'denoise_resnet': lambda: arch.DenoiseResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], norm_type=cfg['norm_type'], mode=cfg['mode'],
            upsample_mode='pixelshuffle', ada_ksize=cfg['ada_ksize'],
            down_scale=cfg['down_scale'], fea_norm=cfg['fea_norm'],
            upsample_norm=cfg['upsample_norm']),
        'modulate_denoise_resnet': lambda: arch.ModulateDenoiseResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], norm_type=cfg['norm_type'], mode=cfg['mode'],
            upsample_mode='pixelshuffle', ada_ksize=cfg['ada_ksize'],
            gate_conv_bias=cfg['gate_conv_bias']),
        'noise_subnet': lambda: arch.NoiseSubNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            norm_type=cfg['norm_type'], mode=cfg['mode']),
        'cond_denoise_resnet': lambda: arch.CondDenoiseResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], upsample_mode='pixelshuffle', ada_ksize=cfg['ada_ksize'],
            down_scale=cfg['down_scale'], num_classes=cfg['num_classes'],
            norm_type=cfg['norm_type']),
        'adabn_denoise_resnet': lambda: arch.AdaptiveDenoiseResNet(
            in_nc=cfg['in_nc'], nf=cfg['nf'], nb=cfg['nb'],
            upscale=cfg['scale'], down_scale=cfg['down_scale']),
        'sft_arch': lambda: sft_arch.SFT_Net(),  # SFT-GAN
        'RRDB_net': lambda: arch.RRDBNet(  # RRDB
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            nb=cfg['nb'], gc=cfg['gc'], upscale=cfg['scale'], norm_type=cfg['norm_type'],
            act_type='leakyrelu', mode=cfg['mode'], upsample_mode='upconv'),
    }

    builder = builders.get(model_name)
    if builder is None:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(model_name))
    net = builder()

    if opt['init_type'] is not None:
        init_weights(net, init_type=opt['init_type'], scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        net = nn.DataParallel(net)
    return net


def define_sub(opt):
    """Build the degradation-estimation subnet from ``opt['network_sub']``.

    Only ``which_model_sub == 'noise_subnet'`` is supported. No weight
    initialization is applied here. When ``opt['gpu_ids']`` is non-empty,
    CUDA availability is asserted and the subnet is wrapped in
    ``nn.DataParallel``.

    Raises:
        NotImplementedError: for any other model name.
    """
    gpu_ids = opt['gpu_ids']
    cfg = opt['network_sub']
    model_name = cfg['which_model_sub']

    if model_name != 'noise_subnet':
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(model_name))

    net = arch.NoiseSubNet(in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'], nb=cfg['nb'],
                           norm_type=cfg['norm_type'], mode=cfg['mode'])

    if gpu_ids:
        assert torch.cuda.is_available()
        net = nn.DataParallel(net)
    return net


# Discriminator
def define_D(opt):
    """Build the discriminator named by ``opt['network_D']['which_model_D']``.

    The result is Kaiming-initialized with scale 1 and wrapped in
    ``nn.DataParallel`` when ``opt['gpu_ids']`` is non-empty.

    Raises:
        NotImplementedError: for an unrecognized model name.
    """
    gpu_ids = opt['gpu_ids']
    cfg = opt['network_D']
    model_name = cfg['which_model_D']

    def build_vgg(cls):
        # The VGG-style discriminators share one constructor signature.
        return cls(in_nc=cfg['in_nc'], base_nf=cfg['nf'], norm_type=cfg['norm_type'],
                   mode=cfg['mode'], act_type=cfg['act_type'])

    builders = {
        'discriminator_vgg_128': lambda: build_vgg(arch.Discriminator_VGG_128),
        # sft-gan, Auxiliary Classifier Discriminator
        'dis_acd': lambda: sft_arch.ACD_VGG_BN_96(),
        'discriminator_vgg_96': lambda: build_vgg(arch.Discriminator_VGG_96),
        'discriminator_vgg_192': lambda: build_vgg(arch.Discriminator_VGG_192),
        'discriminator_vgg_128_SN': lambda: arch.Discriminator_VGG_128_SN(),
    }

    builder = builders.get(model_name)
    if builder is None:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(model_name))
    net = builder()

    init_weights(net, init_type='kaiming', scale=1)
    if gpu_ids:
        net = nn.DataParallel(net)
    return net


def define_F(opt, use_bn=False):
    """Build the VGG feature extractor used for perceptual losses.

    The extractor is created on CUDA when ``opt['gpu_ids']`` is non-empty
    (otherwise CPU), optionally wrapped in ``nn.DataParallel``, and
    returned in eval mode because it is never trained.
    """
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU: the layer index differs
    # between the BN and non-BN variants.
    layer_index = 49 if use_bn else 34
    extractor = arch.VGGFeatureExtractor(feature_layer=layer_index, use_bn=use_bn,
                                         use_input_norm=True, device=device)
    if gpu_ids:
        extractor = nn.DataParallel(extractor)
    extractor.eval()  # No need to train
    return extractor

architecture.py

import math
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from . import block as B
from . import spectral_norm as SN
from . import adaptive_norm as AN

####################
# Generator
####################


class SRCNN(nn.Module):
    """Three-conv SRCNN-style network operating at input resolution.

    Layers: 9x9 feature conv (in_nc -> nf), 1x1 mapping (nf -> nf // 2), and
    a 5x5 reconstruction conv (nf // 2 -> out_nc) without activation. All
    three share ``norm_type``, ``mode`` and ``ada_ksize``.
    """
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(SRCNN, self).__init__()

        common = dict(norm_type=norm_type, mode=mode, ada_ksize=ada_ksize)
        # Built in the original order so seeded initialization is unchanged.
        layers = [
            B.conv_block(in_nc, nf, kernel_size=9, act_type=act_type, **common),
            B.conv_block(nf, nf // 2, kernel_size=1, act_type=act_type, **common),
            B.conv_block(nf // 2, out_nc, kernel_size=5, act_type=None, **common),
        ]
        self.model = B.sequential(*layers)

    def forward(self, x):
        return self.model(x)


class ARCNN(nn.Module):
    """Four-conv ARCNN-style artifact-reduction network.

    Layers: 9x9 (in_nc -> nf), 7x7 (nf -> nf // 2), 1x1 (nf // 2 -> nf // 4)
    and a 5x5 output conv (nf // 4 -> out_nc) without activation. All share
    ``norm_type``, ``mode`` and ``ada_ksize``.
    """
    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super(ARCNN, self).__init__()

        common = dict(norm_type=norm_type, mode=mode, ada_ksize=ada_ksize)
        # (in_channels, out_channels, kernel_size, activation) per layer, in build order.
        specs = [
            (in_nc, nf, 9, act_type),
            (nf, nf // 2, 7, act_type),
            (nf // 2, nf // 4, 1, act_type),
            (nf // 4, out_nc, 5, None),
        ]
        layers = [B.conv_block(c_in, c_out, kernel_size=k, act_type=act, **common)
                  for c_in, c_out, k, act in specs]
        self.model = B.sequential(*layers)

    def forward(self, x):
        return self.model(x)


class SRResNet(nn.Module):
    """SRResNet-style super-resolution generator built as one flat ``B.sequential``.

    Layout: 3x3 feature conv -> ``B.ShortcutBlock`` wrapping ``nb``
    ``B.ResNetBlock``s plus a 3x3 LR conv -> upsampling stage(s) -> two 3x3
    HR convs (the last one has no activation).

    Args:
        in_nc, out_nc: input / output channel counts.
        nf: feature channels.
        nb: number of residual blocks.
        upscale: overall scale; ``int(log2(upscale))`` upsampling stages are
            built, or a single x3 stage when ``upscale == 3``.
        norm_type, act_type, mode, res_scale: forwarded to the residual blocks
            (``norm_type``/``mode`` also to the LR conv).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: for an unknown ``upsample_mode``.
    """
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        # NOTE: module construction order determines how the RNG is consumed
        # during default init, and the flattened layout determines state_dict
        # keys — reordering here breaks seeded runs / pretrained checkpoints.
        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\
            mode=mode, res_scale=res_scale) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            # A single block (third positional arg presumably the x3 factor —
            # confirm in block.py). It is NOT wrapped in a list; the
            # `*upsampler` below unpacks it, which only works if it is an
            # iterable container such as nn.Sequential. NOTE(review): wrapping
            # it in a list would change the state_dict keys of x3 models.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        # Residual blocks + LR conv sit inside the shortcut; upsampling and
        # HR convs follow it.
        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class ModulateSRResNet(nn.Module):
    """SR generator whose residual features are modulated by a side input.

    ``forward`` takes a pair ``(image, cond)``. ``image`` goes through
    ``fea_conv``. ``cond`` is passed alongside the features through the
    ``B.TwoStreamSRResNet`` blocks and into ``LR_norm``. The modulated
    residual is added to the shallow features and upsampled by ``HR_branch``.
    ``cond`` is presumably expected to have ``in_nc`` channels, since both
    ``input_dim`` and the ``LR_norm`` layers are built with ``in_nc`` —
    confirm against the adaptive_norm module.

    Args:
        norm_type: ``'sft'`` (``AN.GateNonLinearLayer``) or ``'sft_conv'``
            (``AN.MetaLayer``); also forwarded to the residual blocks.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: for an unsupported ``norm_type`` or
            ``upsample_mode``. Previously an unsupported ``norm_type`` left
            ``LR_norm`` undefined and only failed later in ``forward``.
    """
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateSRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=1)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]

        self.LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)
        if norm_type == 'sft':
            self.LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            self.LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
        else:
            # Fail at construction instead of with AttributeError in forward().
            raise NotImplementedError('norm type [{}] is not supported by ModulateSRResNet'.format(norm_type))

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if upscale == 3:
            # Single block, star-unpacked below (see SRResNet); kept as-is so
            # state_dict keys of existing x3 checkpoints stay valid.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.HR_branch = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        # x is a pair: x[0] image input, x[1] conditioning input.
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        out = self.HR_branch(fea+res)
        return out


class DenoiseResNet(nn.Module):
    """
    Denoising ResNet (jingwen's addition).

    The first conv uses ``stride=down_scale``, so the residual trunk runs at
    reduced resolution. The upsampling stage(s) then restore the input size:
    ``int(log2(down_scale))`` stages, or one x3 stage when ``down_scale == 3``.
    Layout: strided feature conv -> ``B.ShortcutBlock`` wrapping ``nb``
    ``B.ResNetBlock``s plus an LR conv -> upsampler(s) -> two 3x3 HR convs.

    Args:
        upscale: accepted but unused by this class.
        fea_norm: norm type for the first (strided) conv.
        upsample_norm: norm type for the upsampler(s) and both HR convs.
        ada_ksize: forwarded to every conv/upsample block.

    Raises:
        NotImplementedError: for an unknown ``upsample_mode``.
    """
    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv', ada_ksize=None, down_scale=2,
                 fea_norm=None, upsample_norm=None):
        super(DenoiseResNet, self).__init__()
        # Number of upsampling stages needed to undo the strided first conv.
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=fea_norm, act_type=None, stride=down_scale,
                                ada_ksize=ada_ksize)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                         mode=mode, res_scale=res_scale, ada_ksize=ada_ksize) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode
                               , ada_ksize=ada_ksize)
        # LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode
        #                        , ada_ksize=ada_ksize)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        # NOTE(review): for down_scale == 3 the single block is star-unpacked
        # below (works only if it is an iterable container such as
        # nn.Sequential); wrapping it in a list would change state_dict keys.
        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, norm_type=upsample_norm, ada_ksize=ada_ksize)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, norm_type=upsample_norm, ada_ksize=ada_ksize) for _ in range(n_upscale)]

        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=upsample_norm, act_type=act_type, ada_ksize=ada_ksize)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=upsample_norm, act_type=None, ada_ksize=ada_ksize)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


# class ModulateDenoiseResNet(nn.Module):
#     def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
#                  mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=False, ada_ksize=None):
#         super(ModulateDenoiseResNet, self).__init__()
#
#         self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
#         resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
#                          mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
#                                              ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
#         degration_block = [B.conv_block(in_nc, nf, kernel_size=3, norm_type='batch', act_type='relu')]
#         degration_block.extend([B.conv_block(nf, nf, kernel_size=3, norm_type='batch', act_type='relu')
#                                 for _ in range(15)])
#         degration_block.append(B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None))
#
#         LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)
#         if norm_type == 'sft':
#             LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
#         elif norm_type == 'sft_conv':
#             LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
#
#         if upsample_mode == 'upconv':
#             upsample_block = B.upconv_blcok
#         elif upsample_mode == 'pixelshuffle':
#             upsample_block = B.pixelshuffle_block
#         else:
#             raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
#         upsampler = upsample_block(nf, nf, act_type=act_type)
#         HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
#         HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
#
#         self.norm_branch = B.sequential(*resnet_blocks)
#         self.LR_conv = LR_conv
#         self.LR_norm = LR_norm
#         self.degration_block = B.sequential(*degration_block)
#         self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)
#
#     def forward(self, x):
#         fea = self.fea_conv(x)
#         # noise estimation part
#         # deg_estimate = self.degration_block(x) + x
#         deg_estimate = self.degration_block(x)
#         fea_res_block, _ = self.norm_branch((fea, deg_estimate))
#         fea_LR = self.LR_conv(fea_res_block)
#         res = self.LR_norm((fea_LR, deg_estimate))
#         out = self.HR_branch(fea+res)
#         return out


class ModulateDenoiseResNet(nn.Module):
    """Denoising ResNet whose features are modulated by an external condition map.

    The image is downsampled by a stride-2 ``fea_conv``. It then passes through ``nb``
    ``B.TwoStreamSRResNet`` blocks that also receive the condition map, followed by
    ``LR_conv`` and a condition-dependent ``LR_norm`` layer. A global residual
    (``fea + res``) is taken, and the result goes through one upsampling block and two
    output convs.

    Args:
        in_nc: channels of the input image. Also passed as ``input_dim`` to the blocks
            and as the first argument of ``LR_norm``; presumably the condition map has
            ``in_nc`` channels too (TODO confirm).
        out_nc: channels of the output.
        nf: feature channels.
        nb: number of two-stream residual blocks.
        upscale: unused; kept for interface compatibility.
        norm_type: ``'sft'`` (``AN.GateNonLinearLayer``) or ``'sft_conv'``
            (``AN.MetaLayer``) for the final modulation layer; also forwarded to the
            residual blocks.
        act_type, mode, res_scale, gate_conv_bias, ada_ksize: forwarded to the blocks /
            modulation layer.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: unknown ``norm_type`` or ``upsample_mode``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateDenoiseResNet, self).__init__()

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]

        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)
        if norm_type == 'sft':
            LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
        else:
            # Without this branch an unsupported value left LR_norm unbound and the
            # constructor died with an opaque UnboundLocalError further down.
            raise NotImplementedError('norm type [%s] is not found' % norm_type)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        # One upsampling block, presumably undoing the stride-2 fea_conv (TODO confirm
        # the default factor of upsample_block).
        upsampler = upsample_block(nf, nf, act_type=act_type)
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.LR_conv = LR_conv
        self.LR_norm = LR_norm
        self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Run the network on an ``(image, condition)`` pair.

        Args:
            x: indexable pair. ``x[0]`` is the degraded image fed to ``fea_conv``.
                ``x[1]`` is the condition map given to every modulated layer;
                presumably the noise-level map produced by a ``NoiseSubNet``
                (TODO confirm).

        Returns:
            Output of ``HR_branch``.
        """
        fea = self.fea_conv(x[0])
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        # Global residual connection in the downsampled feature space.
        out = self.HR_branch(fea + res)
        return out


class NoiseSubNet(nn.Module):
    """Plain stack of 3x3 conv blocks used as the degradation (noise-level) estimator.

    Layout: one ``in_nc -> nf`` conv block, fifteen ``nf -> nf`` conv blocks, and a
    final ``nf -> out_nc`` conv block with no normalization or activation.

    Args:
        in_nc: channels of the input image.
        out_nc: channels of the estimate.
        nf: hidden feature channels.
        nb: NOTE(review): currently unused; the number of middle blocks is fixed at 15.
        norm_type, act_type: used by every block except the last.
        mode: conv-block ordering, used by every block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, norm_type='batch', act_type='relu', mode='CNA'):
        super(NoiseSubNet, self).__init__()
        hidden_kwargs = dict(kernel_size=3, norm_type=norm_type, act_type=act_type, mode=mode)
        head = B.conv_block(in_nc, nf, **hidden_kwargs)
        body = [B.conv_block(nf, nf, **hidden_kwargs) for _ in range(15)]
        tail = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, mode=mode)
        self.degration_block = B.sequential(head, *body, tail)

    def forward(self, x):
        """Return the degradation estimate computed from ``x``."""
        return self.degration_block(x)


class CondDenoiseResNet(nn.Module):
    """
    jingwen's addition
    denoise Resnet

    Conditional denoising ResNet. The input is downsampled by a strided ``fea_conv``
    and processed by ``nb`` ``B.CondResNetBlock`` blocks plus ``LR_conv``, all
    conditioned on ``y``. A conditional adaptive layer (chosen by ``norm_type``) is then
    applied, a global residual is added, and the result is upsampled back.

    Args:
        in_nc, out_nc: input / output channels.
        nf: feature channels.
        nb: number of conditional residual blocks.
        upscale, res_scale: unused; kept for interface compatibility.
        down_scale: stride of ``fea_conv``. ``log2(down_scale)`` upsampling blocks
            undo it, or a single block called with factor 3 when ``down_scale == 3``.
        num_classes, ada_ksize: forwarded to the conditional layers.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.
        act_type: activation type.
        norm_type: one of ``'cond_adaptive_conv_res'``, ``'interp_adaptive_conv_res'``,
            ``'cond_instance'``, ``'cond_transform_res'``.

    Raises:
        NotImplementedError: unknown ``norm_type`` or ``upsample_mode``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, res_scale=1, down_scale=2, num_classes=1,
                 ada_ksize=None, upsample_mode='upconv', act_type='relu', norm_type='cond_adaptive_conv_res'):
        super(CondDenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.CondResNetBlock(nf, nf, nf, num_classes=num_classes, ada_ksize=ada_ksize,
                                           norm_type=norm_type, act_type=act_type) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        if norm_type == 'cond_adaptive_conv_res':
            self.cond_adaptive = AN.CondAdaptiveConvResNorm(nf, num_classes=num_classes)
        elif norm_type == "interp_adaptive_conv_res":
            self.cond_adaptive = AN.InterpAdaptiveResNorm(nf, ada_ksize)
        elif norm_type == "cond_instance":
            self.cond_adaptive = AN.CondInstanceNorm2d(nf, num_classes=num_classes)
        elif norm_type == "cond_transform_res":
            self.cond_adaptive = AN.CondResTransformer(nf, ada_ksize, num_classes=num_classes)
        else:
            # Fail at construction time instead of with an AttributeError on
            # self.cond_adaptive at the first forward() call.
            raise NotImplementedError('norm type [%s] is not found' % norm_type)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.upsample = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x, y):
        """Denoise ``x`` conditioned on ``y``.

        Args:
            x: input image batch.
            y: conditioning input passed to every conditional block and to
                ``cond_adaptive``. Its exact form (class index, interpolation weight,
                ...) depends on ``norm_type``; TODO confirm with the AN layers.

        Returns:
            Output of the upsampling head.
        """
        # the first feature extraction
        fea = self.fea_conv(x)
        fea1, _ = self.resnet_blocks((fea, y))
        fea2 = self.LR_conv(fea1)
        fea3 = self.cond_adaptive(fea2, y)
        # Global residual connection before upsampling.
        out = self.upsample(fea3 + fea)
        return out


class AdaptiveDenoiseResNet(nn.Module):
    """
    jingwen's addition
    adabn

    Collects feature statistics over a set of images and writes them into
    ``self.batch_norm``'s running mean / variance.

    Args:
        in_nc: input channels.
        nf: feature channels.
        nb: number of ``B.AdaptiveResNetBlock`` blocks.
        upscale: unused; kept for interface compatibility.
        res_scale: forwarded to the residual blocks.
        down_scale: stride of ``fea_conv``.
    """
    def __init__(self, in_nc, nf, nb, upscale=1, res_scale=1, down_scale=2):
        super(AdaptiveDenoiseResNet, self).__init__()

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.AdaptiveResNetBlock(nf, nf, nf, res_scale=res_scale) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        # momentum=0 means training-mode forward passes never update the running
        # statistics; they are only set explicitly in forward() below.
        self.batch_norm = nn.BatchNorm2d(nf, affine=True, track_running_stats=True, momentum=0)

    def forward(self, x):
        """Compute feature statistics over ``x`` and store them in ``batch_norm``.

        Args:
            x: iterable of single-image tensors; each gets a leading batch dimension
                added, so presumably ``(C, H, W)`` (TODO confirm).

        Returns:
            None. As a side effect, ``self.batch_norm``'s ``running_mean`` and
            ``running_var`` are replaced by the values from
            ``B.computing_mean_variance``.
        """
        # Use the out-of-place ``unsqueeze``. The previous ``unsqueeze_`` mutated the
        # caller's tensors, so a second call would pass 5-D input to the conv.
        fea_list = [self.fea_conv(data.unsqueeze(0)) for data in x]
        fea_resblock_list = self.resnet_blocks(fea_list)
        fea_LR_list = [self.LR_conv(fea) for fea in fea_resblock_list]
        fea_mean, fea_var = B.computing_mean_variance(fea_LR_list)

        batch_norm_dict = self.batch_norm.state_dict()
        batch_norm_dict['running_mean'] = fea_mean
        batch_norm_dict['running_var'] = fea_var
        self.batch_norm.load_state_dict(batch_norm_dict)
        return None



class RRDBNet(nn.Module):
    """Super-resolution network built from ``B.RRDB`` blocks.

    Layout: ``fea_conv`` -> shortcut(``nb`` RRDB blocks + ``LR_conv``) -> upsampling
    block(s) -> ``HR_conv0`` -> ``HR_conv1``.

    Args:
        in_nc, out_nc: input / output channels.
        nf: feature channels.
        nb: number of RRDB blocks.
        gc: growth channels passed to every ``B.RRDB``. It used to be ignored in
            favour of a hard-coded 32, which is still the default.
        upscale: overall scale factor. ``log2(upscale)`` upsampling blocks are used,
            or a single block called with factor 3 when ``upscale == 3``.
        norm_type: normalization for the RRDB blocks and ``LR_conv``.
        act_type: activation type.
        mode: conv-block ordering for ``LR_conv``; the RRDB blocks always use ``'CNA'``.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: unknown ``upsample_mode``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        # The RRDB trunk + LR_conv is wrapped in a shortcut (global residual).
        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        x = self.model(x)
        return x


####################
# Discriminator
####################


# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator for ``in_nc x 128 x 128`` inputs.

    Ten conv blocks alternate a 3x3/stride-1 block with a 4x4/stride-2 block. Each
    stride-2 block halves the spatial size (128 -> 4), and the channel count grows to
    ``base_nf * 8``. The flattened ``(base_nf * 8) x 4 x 4`` map goes through a
    two-layer MLP that outputs one unbounded score per sample.

    Args:
        in_nc: input channels.
        base_nf: channels of the first stage (64 in the original configuration).
        norm_type: normalization for every conv block except the first.
        act_type: activation type.
        mode: conv-block ordering.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, base_nf
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 64, base_nf
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 32, base_nf*2
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 16, base_nf*4
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 8, base_nf*8
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 4, base_nf*8
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
                                     conv9)

        # classifier: the input size is derived from base_nf. It was hard-coded to
        # 512 * 4 * 4, which only matches base_nf == 64 and made any other width fail
        # in forward().
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` score tensor for the image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    """Spectral-normalized VGG-style discriminator for 3 x 128 x 128 inputs.

    Ten spectral-normalized convolutions, each followed by LeakyReLU(0.2), shrink the
    input to 512 x 4 x 4. Two spectral-normalized linear layers then produce one
    score per sample.
    """

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2, True)

        # (in_channels, out_channels, kernel_size, stride, padding) for conv0..conv9.
        # Each 4x4 / stride-2 layer halves the spatial size: 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        conv_specs = [
            (3, 64, 3, 1, 1), (64, 64, 4, 2, 1),
            (64, 128, 3, 1, 1), (128, 128, 4, 2, 1),
            (128, 256, 3, 1, 1), (256, 256, 4, 2, 1),
            (256, 512, 3, 1, 1), (512, 512, 4, 2, 1),
            (512, 512, 3, 1, 1), (512, 512, 4, 2, 1),
        ]
        for idx, spec in enumerate(conv_specs):
            # Keep the attribute names conv0..conv9 so existing checkpoints still load.
            setattr(self, 'conv%d' % idx, SN.spectral_norm(nn.Conv2d(*spec)))

        # Classifier on the flattened 512 x 4 x 4 feature map.
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` score tensor for the image batch ``x``."""
        feat = x
        for idx in range(10):
            feat = self.lrelu(getattr(self, 'conv%d' % idx)(feat))
        flat = feat.view(feat.size(0), -1)
        hidden = self.lrelu(self.linear0(flat))
        return self.linear1(hidden)


class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator for ``in_nc x 96 x 96`` inputs.

    Ten conv blocks alternate a 3x3/stride-1 block with a 4x4/stride-2 block. Each
    stride-2 block halves the spatial size (96 -> 3), and the channel count grows to
    ``base_nf * 8``. The flattened ``(base_nf * 8) x 3 x 3`` map goes through a
    two-layer MLP that outputs one unbounded score per sample.

    Args:
        in_nc: input channels.
        base_nf: channels of the first stage (64 in the original configuration).
        norm_type: normalization for every conv block except the first.
        act_type: activation type.
        mode: conv-block ordering.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, base_nf
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 48, base_nf
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 24, base_nf*2
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 12, base_nf*4
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 6, base_nf*8
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 3, base_nf*8
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
                                     conv9)

        # classifier: the input size is derived from base_nf. It was hard-coded to
        # 512 * 3 * 3, which only matches base_nf == 64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` score tensor for the image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator for ``in_nc x 192 x 192`` inputs.

    Twelve conv blocks alternate a 3x3/stride-1 block with a 4x4/stride-2 block. Each
    stride-2 block halves the spatial size (192 -> 3), and the channel count grows to
    ``base_nf * 8``. The flattened ``(base_nf * 8) x 3 x 3`` map goes through a
    two-layer MLP that outputs one unbounded score per sample.

    Args:
        in_nc: input channels.
        base_nf: channels of the first stage (64 in the original configuration).
        norm_type: normalization for every conv block except the first.
        act_type: activation type.
        mode: conv-block ordering.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, base_nf
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 96, base_nf
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 48, base_nf*2
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 24, base_nf*4
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 12, base_nf*8
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 6, base_nf*8
        conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                              act_type=act_type, mode=mode)
        conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                              act_type=act_type, mode=mode)
        # 3, base_nf*8
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
                                     conv9, conv10, conv11)

        # classifier: the input size is derived from base_nf. It was hard-coded to
        # 512 * 3 * 3, which only matches base_nf == 64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` score tensor for the image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


####################
# Perceptual Network
####################


# Assume input range is [0, 1]
class VGGFeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained VGG19 feature extractor.

    Keeps ``vgg.features`` up to and including layer ``feature_layer``. Every kept
    parameter has ``requires_grad=False``.

    Args:
        feature_layer: index of the last ``vgg.features`` layer to keep (inclusive).
        use_bn: build ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: normalize inputs with the ImageNet mean/std before the network.
            This assumes inputs in [0, 1].
        device: device on which the mean/std buffers are created.
    """

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        build_vgg = torchvision.models.vgg19_bn if use_bn else torchvision.models.vgg19
        vgg = build_vgg(pretrained=True)
        self.use_input_norm = use_input_norm
        if use_input_norm:
            # ImageNet statistics for inputs in [0, 1]. For inputs in [-1, 1], the
            # equivalent constants are 2 * mean - 1 and 2 * std.
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device))
        kept_layers = list(vgg.features.children())[:feature_layer + 1]
        self.features = nn.Sequential(*kept_layers)
        # Used as a fixed loss network: its weights never receive gradients.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the features of ``x`` (normalized first if ``use_input_norm``)."""
        normalized = (x - self.mean) / self.std if self.use_input_norm else x
        return self.features(normalized)


# Assume input range is [0, 1]
class ResNet101FeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained ResNet-101 trunk used as a feature extractor.

    Keeps the first eight top-level children of ``torchvision.models.resnet101``.
    This is presumably the stem plus layer1..layer4, i.e. everything before the
    pooling/fc head (TODO confirm against the installed torchvision). Every kept
    parameter has ``requires_grad=False``.

    Args:
        use_input_norm: normalize inputs with the ImageNet mean/std before the network.
            This assumes inputs in [0, 1].
        device: device on which the mean/std buffers are created.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        backbone = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if use_input_norm:
            # ImageNet statistics for inputs in [0, 1]. For inputs in [-1, 1], the
            # equivalent constants are 2 * mean - 1 and 2 * std.
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device))
        self.features = nn.Sequential(*list(backbone.children())[:8])
        # Used as a fixed loss network: its weights never receive gradients.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the features of ``x`` (normalized first if ``use_input_norm``)."""
        normalized = (x - self.mean) / self.std if self.use_input_norm else x
        return self.features(normalized)


class MINCNet(nn.Module):
    """VGG16-style convolutional trunk (no classifier head).

    It has five stages of 3x3 / stride-1 / padding-1 convolutions: 2, 2, 3, 3 and 3
    convs with 64, 128, 256, 512 and 512 output channels. The stages are separated by
    2x2 / stride-2 max-pools in ceil mode. Every conv is followed by ReLU except the
    very last one (``conv53``), whose raw response is returned.

    Layers are registered as ``ReLU``, ``conv<stage><index>`` and ``maxpool<stage>``, in
    stage order. A state dict using those key names can therefore be loaded with
    ``strict=True``.
    """

    # (output channels, number of 3x3 convs) for each of the five stages.
    _STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        in_ch = 3
        last_stage = len(self._STAGES)
        for stage, (out_ch, n_convs) in enumerate(self._STAGES, start=1):
            for idx in range(1, n_convs + 1):
                setattr(self, 'conv%d%d' % (stage, idx), nn.Conv2d(in_ch, out_ch, 3, 1, 1))
                in_ch = out_ch
            if stage != last_stage:
                setattr(self, 'maxpool%d' % stage,
                        nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True))

    def forward(self, x):
        """Return the ``conv53`` response (before ReLU) for the image batch ``x``."""
        out = x
        last_stage = len(self._STAGES)
        for stage, (_, n_convs) in enumerate(self._STAGES, start=1):
            for idx in range(1, n_convs + 1):
                out = getattr(self, 'conv%d%d' % (stage, idx))(out)
                is_final_conv = stage == last_stage and idx == n_convs
                if not is_final_conv:
                    out = self.ReLU(out)
            if stage != last_stage:
                out = getattr(self, 'maxpool%d' % stage)(out)
        return out


# Assume input range is [0, 1]
class MINCFeatureExtractor(nn.Module):
    """Frozen ``MINCNet`` feature extractor loaded from a pretrained checkpoint.

    The checkpoint presumably holds VGG16 weights trained on the MINC dataset (file
    name ``VGG16minc_53.pth``; TODO confirm). It is loaded with ``strict=True``, the
    network is put in eval mode, and all parameters are frozen.

    Args:
        feature_layer, use_bn, use_input_norm, device: unused; kept so the signature
            matches the other feature extractors.
        model_path: checkpoint path. The default is relative to the current working
            directory, as before.
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu'),
                 model_path='../experiments/pretrained_models/VGG16minc_53.pth'):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        # map_location='cpu': the freshly built MINCNet lives on the CPU anyway. Loading
        # straight there also works on machines without the GPU the checkpoint was
        # saved from, where a plain torch.load would fail.
        state_dict = torch.load(model_path, map_location='cpu')
        self.features.load_state_dict(state_dict, strict=True)
        self.features.eval()
        # Used as a fixed loss network: its weights never receive gradients.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the ``MINCNet`` features of ``x`` (no input normalization)."""
        output = self.features(x)
        return output

实验结果:

两者结果对比如上图所示。直观上,将noise level map与LR图像concat到一起的方案视觉效果更好一些。

先denoise后SR的级联网络效果如下图

而先noise estimation 后SR的级联网络的效果如下图

后者的PSNR比前者高约0.3dB。

猜你喜欢

转载自blog.csdn.net/gwplovekimi/article/details/85688031