实验——基于SRResNet的denoise

版权声明: https://blog.csdn.net/gwplovekimi/article/details/85059863

本博文是采用SRResNet来实现图像去噪的实验记录

python train.py -opt options/train/train_ne.json.json

Setting(不知道为什么,用我自己实现的SRResNet需要两块GPU才能跑起来,而用xintao的实现一块GPU就够了;但两者都跑得很慢,所以最终还是决定两块卡一起跑)

{
  "name": "n_g_srresnet" //"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard wounld not work
  ,
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
   "scale": 1//it must be 1
  ,
  "gpu_ids": [4],
  "datasets": {
    "train": {
      "name": "DIV2K800",
      "mode": "LRHR" //it must be this, and the detail would be shown in LRHR_dataset.py
      //, "noise_get": true///////////////////////////////////////////////////////////////////////
      ,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" ///////////must be sub
      ,
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_n",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 16//32 //how many samples in each iters
      ,
      "HR_size": 192 // 128 | 192
      ,
      "use_flip": false //true//////////////////////////
      ,
      "use_rot": false //true////////////////////////
    },
    "val": {
      "name": "val_set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_n"
      //, "noise_get": true///////////////////////////////////////////////////////////////////////this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/n_g_srresnet/models",
    "log": "/home/guanwp/BasicSR-master/experiments/n_g_srresnet",
    "val_images": "/home/guanwp/BasicSR-master/experiments/n_g_srresnet/val_images"
  },
  "network_G": {
    "which_model_G": "srresnet"//"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    ,
    "norm_type": null,
    "mode": "CNA",
    "nf": 64 //56//64
    ,
    "nb": 23,
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 1e-3//8e-4 //1e-3//2e-4
    ,
    "lr_scheme": "MultiStepLR",
    "lr_steps": [100000,200000,300000,400000,600000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l2" //"l2_tv"//"l1"//'l2'//huber//Cross   //should be MSE LOSS
    ,
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 8e5 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}

网络结构

##########################################################################################################
#SRResNet, the Original version
#define the residual block
class O_Residual_Block(nn.Module):
    """SRResNet residual block: conv-BN-PReLU-conv-BN plus an identity skip.

    Args:
        nf: number of feature channels of both 3x3 convolutions and both
            BatchNorm layers. Defaults to 64 (the width this block was
            originally hard-coded to), so ``O_Residual_Block()`` builds the
            same module, with the same ``state_dict`` keys, as before.
    """

    def __init__(self, nf=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(nf, affine=True)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(nf, affine=True)

    def forward(self, x):
        """Return ``x + bn2(conv2(prelu(bn1(conv1(x)))))``; output shape equals ``x``'s."""
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.add(out, x)


##############################################
class OSRRESNET(nn.Module):
    """SRResNet ("original version") used as a same-resolution model.

    Pipeline: 9x9 conv + PReLU -> 16 residual blocks -> 3x3 conv + BN with a
    long skip from the head features -> 9x9 output conv. Every convolution is
    stride 1 with size-preserving padding, so the output has the input's
    height and width.

    Args:
        in_nc: channels of the input image.
        out_nc: channels of the output image.
        nf: feature width used everywhere in the trunk (head conv, residual
            blocks, mid conv/BN, and the unused upsampling tail).
        nb: ignored -- the trunk is fixed at 16 residual blocks (presumably
            to match the original SRResNet; changing it would also break
            loading of existing checkpoints).
        upscale, norm_type, act_type, mode, res_scale, upsample_mode: ignored.
            Presumably accepted so a model factory can pass its usual keyword
            arguments -- TODO confirm against the caller.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):  # note: upscale is not applied
        super().__init__()

        self.conv_input = nn.Conv2d(in_channels=in_nc, out_channels=nf, kernel_size=9, stride=1, padding=4, bias=False)
        self.prelu = nn.PReLU()

        # Residual blocks must use the same width as conv_input/conv_mid;
        # hard-coding 64 here made every nf != 64 fail at forward time.
        self.residual = self.make_layer(O_Residual_Block, 16, nf)

        self.conv_mid = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_mid = nn.BatchNorm2d(nf, affine=True)

        # 2 x (conv -> PixelShuffle(2)) would give a 4x upscaling tail, but
        # forward() never calls it. It is kept only so the module's
        # state_dict keys (and therefore saved checkpoints) stay unchanged.
        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=nf, out_channels=nf * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(in_channels=nf, out_channels=nf * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        self.conv_output = nn.Conv2d(in_channels=nf, out_channels=out_nc, kernel_size=9, stride=1, padding=4, bias=False)

    def make_layer(self, block, num_of_layer, *block_args, **block_kwargs):
        """Return an ``nn.Sequential`` of ``num_of_layer`` fresh ``block(*block_args, **block_kwargs)``."""
        return nn.Sequential(*[block(*block_args, **block_kwargs) for _ in range(num_of_layer)])

    def forward(self, x):
        """Map an ``in_nc``-channel image to an ``out_nc``-channel image of the same size."""
        out = self.prelu(self.conv_input(x))
        skip = out
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out, skip)
        # self.upscale4x is deliberately not applied: for denoising the output
        # must keep the input's spatial size (the config sets "scale": 1).
        return self.conv_output(out)

 
 ##########################################################################################################################################################

然后会发现网络训练到一半时突然跑飞(loss发散)。

网络改为(注:结构与上面的OSRRESNET完全相同,只是换了类名;真正的改动在下面的配置中——学习率1e-3→6e-4及衰减步数、损失函数l2→l1、HR_size 192→128,以及改用两块GPU):

##########################################################################################################
#SRResNet, the Original version
#define the residual block
class DN_Residual_Block(nn.Module):
    """Residual block for the denoising ResNet: conv-BN-PReLU-conv-BN + identity skip.

    Args:
        nf: number of feature channels of both 3x3 convolutions and both
            BatchNorm layers. Defaults to 64 (the width this block was
            originally hard-coded to), so ``DN_Residual_Block()`` is unchanged.
    """

    def __init__(self, nf=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(nf, affine=True)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(nf, affine=True)

    def forward(self, x):
        """Return ``x + bn2(conv2(prelu(bn1(conv1(x)))))``; output shape equals ``x``'s."""
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.add(out, x)


##############################################
class DN_ResNet(nn.Module):
    """SRResNet-style network used for same-resolution denoising.

    Pipeline: 9x9 conv + PReLU -> ``nb`` residual blocks -> 3x3 conv + BN with
    a long skip from the head features -> 9x9 output conv. Every convolution
    is stride 1 with size-preserving padding, so the output has the input's
    height and width.

    Args:
        in_nc: channels of the input image.
        out_nc: channels of the output image.
        nf: feature width used everywhere in the trunk.
        nb: number of residual blocks (the accompanying config passes 16,
            the value that used to be hard-coded, so it builds the same net).
        upscale, norm_type, act_type, mode, res_scale, upsample_mode: ignored.
            Presumably accepted so a model factory can pass its usual keyword
            arguments -- TODO confirm against the caller.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):  # note: upscale is not applied
        super().__init__()

        self.conv_input = nn.Conv2d(in_channels=in_nc, out_channels=nf, kernel_size=9, stride=1, padding=4, bias=False)
        self.prelu = nn.PReLU()

        # Use this network's own block type (previously O_Residual_Block from
        # another class, leaving DN_Residual_Block unused), honour nb, and
        # keep the blocks' width equal to nf.
        self.residual = self.make_layer(DN_Residual_Block, nb, nf)

        self.conv_mid = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_mid = nn.BatchNorm2d(nf, affine=True)

        # Upsampling tail that forward() never calls; kept only so the
        # state_dict keys (and therefore saved checkpoints) stay unchanged.
        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=nf, out_channels=nf * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(in_channels=nf, out_channels=nf * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        self.conv_output = nn.Conv2d(in_channels=nf, out_channels=out_nc, kernel_size=9, stride=1, padding=4, bias=False)

    def make_layer(self, block, num_of_layer, *block_args, **block_kwargs):
        """Return an ``nn.Sequential`` of ``num_of_layer`` fresh ``block(*block_args, **block_kwargs)``."""
        return nn.Sequential(*[block(*block_args, **block_kwargs) for _ in range(num_of_layer)])

    def forward(self, x):
        """Map an ``in_nc``-channel image to an ``out_nc``-channel image of the same size."""
        out = self.prelu(self.conv_input(x))
        skip = out
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out, skip)
        # self.upscale4x is deliberately not applied: denoising keeps the
        # input's spatial size (the config sets "scale": 1).
        return self.conv_output(out)

 
 ##########################################################################################################################################################

setting改为:

{
  "name": "srresnet" //"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard wounld not work
  ,
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
   "scale": 1//it must be 1
  ,
  "gpu_ids": [4,5],
  "datasets": {
    "train": {
      "name": "DIV2K800",
      "mode": "LRHR" //it must be this, and the detail would be shown in LRHR_dataset.py
      //, "noise_get": true///////////////////////////////////////////////////////////////////////
      ,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" ///////////must be sub
      ,
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_n",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 16//32 //how many samples in each iters
      ,
      "HR_size": 128 // 128 | 192
      ,
      "use_flip": false //true//////////////////////////
      ,
      "use_rot": false //true////////////////////////
    },
    "val": {
      "name": "val_set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_n"
      //, "noise_get": true///////////////////////////////////////////////////////////////////////this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/srresnet/models",
    "log": "/home/guanwp/BasicSR-master/experiments/srresnet",
    "val_images": "/home/guanwp/BasicSR-master/experiments/srresnet/val_images"
  },
  "network_G": {
    "which_model_G": "dn_srresnet"//"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    ,
    "norm_type": null,
    "mode": "CNA",
    "nf": 64 //56//64
    ,
    "nb": 16,//////////////number of residual block
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 6e-4//8e-4 //1e-3//2e-4
    ,
    "lr_scheme": "MultiStepLR",
    "lr_steps": [200000,300000,400000,600000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l1" //"l2_tv"//"l1"//'l2'//huber//Cross   //should be MSE LOSS
    ,
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 8e5 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}

猜你喜欢

转载自blog.csdn.net/gwplovekimi/article/details/85059863