基于PyTorch的ESPCN

版权声明: https://blog.csdn.net/gwplovekimi/article/details/84622539

在之前的博客（学习笔记之——基于深度学习的图像超分辨率重构）中已经介绍了ESPCN的原理，本博文对ESPCN进行实现。代码框架仍然采用xintao前辈的代码，本博文只给出关键的实现部分。

python train.py -opt options/train/train_sr.json

python test.py -opt options/test/test_sr.json

ESPCN的网络结构

代码

network.py

# Generator
def define_G(opt):
    """Construct the generator network selected in the options.

    Args:
        opt: parsed option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']``
            and the ``opt['network_G']`` sub-dict, whose ``'which_model_G'``
            entry (set in the .json options file) chooses the architecture.

    Returns:
        The generator module, wrapped in ``nn.DataParallel`` when
        ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if ``which_model_G`` names no known model.
    """
    gpu_ids = opt['gpu_ids']
    net_opt = opt['network_G']
    # Adding a new architecture means adding an entry to the table below
    # (or a new branch) for its ``which_model_G`` name.
    model_name = net_opt['which_model_G']

    # Models built from ``arch`` with one shared keyword signature, ReLU
    # activations and pixel-shuffle upsampling. Values are attribute names so
    # that only the selected class is ever looked up on ``arch``.
    pixelshuffle_classes = {
        'sr_resnet': 'SRResNet',   # SRResNet
        'fsrcnn': 'FSRCNN',        # FSRCNN
        'espcn': 'ESPCN',          # ESPCN
        'srresnet': 'OSRRESNET',   # SRResNet, the original version
    }

    if model_name in pixelshuffle_classes:
        net_cls = getattr(arch, pixelshuffle_classes[model_name])
        netG = net_cls(
            in_nc=net_opt['in_nc'],
            out_nc=net_opt['out_nc'],
            nf=net_opt['nf'],
            nb=net_opt['nb'],
            upscale=net_opt['scale'],
            norm_type=net_opt['norm_type'],
            act_type='relu',
            mode=net_opt['mode'],
            upsample_mode='pixelshuffle',
        )
    elif model_name == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif model_name == 'RRDB_net':  # RRDB, the ESRGAN generator
        netG = arch.RRDBNet(
            in_nc=net_opt['in_nc'],
            out_nc=net_opt['out_nc'],
            nf=net_opt['nf'],
            nb=net_opt['nb'],
            gc=net_opt['gc'],
            upscale=net_opt['scale'],
            norm_type=net_opt['norm_type'],
            act_type='leakyrelu',
            mode=net_opt['mode'],
            upsample_mode='upconv',
        )
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(model_name))

    if opt['is_train']:
        # Weight initialisation scheme; change init_type/scale here to try another.
        init_weights(netG, init_type='kaiming', scale=0.1)

    if not gpu_ids:
        return netG
    assert torch.cuda.is_available()
    return nn.DataParallel(netG)

architecture.py

#######################################################################################################3
#ESPCN
class ESPCN(nn.Module):
    """ESPCN super-resolution network: three convolutions plus a sub-pixel shuffle.

    Only ``in_nc``, ``out_nc`` and ``upscale`` affect the architecture. The
    remaining parameters are accepted so the constructor matches the shared
    keyword signature used when building generators, and are otherwise ignored.

    Args:
        in_nc: number of channels of the low-resolution input.
        out_nc: number of channels of the high-resolution output.
        nf, nb: unused (layer widths are fixed at 64 and 32).
        upscale: integer upscaling factor applied by ``nn.PixelShuffle``.
        norm_type, act_type, mode, res_scale, upsample_mode: unused.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(ESPCN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_nc, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        # PixelShuffle(r) maps C * r**2 channels to C channels at r-times the
        # spatial size, so this conv must emit out_nc * r**2 channels for the
        # network output to have out_nc channels.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=out_nc * (upscale ** 2),
                               kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x):
        """Upscale ``x`` by ``upscale``; outputs pass through a sigmoid, so lie in [0, 1]."""
        out = self.conv1(x).tanh()
        out = self.conv2(out).tanh()
        return self.pixel_shuffle(self.conv3(out)).sigmoid()

setting

{
  "name": "espcn_x4"//"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard wounld not work
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "gpu_ids": [4]

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16//how many samples in each iters
      , "HR_size": 192 // 128 | 192
      , "use_flip": true
      , "use_rot": true
    }
    , "val": {
      "name": "val_set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
    }
  }

  , "path": {
    "root": "/home/guanwp/BasicSR-master",
    "pretrain_model_G": null
     ,"experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/espcn_x4/models",
    "log": "/home/guanwp/BasicSR-master/experiments/espcn_x4",
    "val_images": "/home/guanwp/BasicSR-master/experiments/espcn_x4/val_images"
  }

  , "network_G": {
    "which_model_G": "espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64//56//64
    , "nb": 23
    , "in_nc": 3
    , "out_nc": 3
    , "gc": 32
    , "group": 1
  }

  , "train": {
    "lr_G": 1e-3//1e-3//2e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000,400000,600000,800000,1000000,1500000]
    , "lr_gamma": 0.5

    , "pixel_criterion": "l1"//"l1"//'l2'//huber//Cross
    , "pixel_weight": 1.0
    , "val_freq": 5e3

    , "manual_seed": 0
    , "niter": 2e6//2e6//1e6
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 5e3
  }
}

 

target

result

结果比原文的结果略差一点，推测是由于一开始学习率太小导致的。

猜你喜欢

转载自blog.csdn.net/gwplovekimi/article/details/84622539