DnCNN代码学习—main_train.py

                                      DnCNN代码学习—main_train.py

一、源代码+注释

# -*- coding: utf-8 -*-

# PyTorch 0.4.1, https://pytorch.org/docs/stable/index.html

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to train the model

# =============================================================================
# For batch normalization layer, momentum should be a value from [0.1, 1] rather than the default 0.1. 
# The Gaussian noise output helps to stablize the batch normalization, thus a large momentum (e.g., 0.95) is preferred.
# =============================================================================

import argparse   # python的参数解析argparse模块
import re  #Python正则表达式re模块

#datetime模块提供了简单和复杂的方式用于操纵日期和时间的类。
import os, glob, datetime, time   #文件名操作模块glob 
import numpy as np  #NumPy(Numerical Python) 是 Python 语言的一个扩展程序库,支持大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库。
import torch   #包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作。
#torch.nn,nn就是neural network的缩写,这是一个专门为深度学习而设计的模块。torch.nn的核心数据结构是Module,这是一个抽象的概念,
#既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。在实际使用中,最常见的做法是继承nn.Module,撰写自己的网络/层。
import torch.nn as nn    

from torch.nn.modules.loss import _Loss  #没一个损失函数作为一个类,不继承自_Loss类,而_Loss类又继承自Module类
import torch.nn.init as init  #初始化
from torch.utils.data import DataLoader  #DataLoader类的作用就是实现数据以什么方式输入到什么网络中。
import torch.optim as optim   #torch.optim是一个实现了各种优化算法的库。 
from torch.optim.lr_scheduler import MultiStepLR   #torch.optim.lr_scheduler提供了几种方法来根据迭代的数量来调整学习率  #按需调整学习率 MultiStepLR按设定的间隔调整学习率。
import data_generator as dg  #导入处理数据文件,命名dg,数据生成器
from data_generator import DenoisingDataset  #从data_generator文件导入 DenoisingDataset类

#创建一个解析器
#使用 argparse 的第一步是创建一个 ArgumentParser 对象
parser = argparse.ArgumentParser(description='PyTorch DnCNN')
#参加参数
#给一个 ArgumentParser 添加程序参数信息是通过调用 add_argument() 方法完成的。通常,这些调用指定 ArgumentParser 如何获取命令行字符串并将其转换为对象。
#模型  字符串   默认为DnCNN
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
#批量大小  整型   默认大小128 
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
#训练数据   字符串  默认 data/Train400  路径 
parser.add_argument('--train_data', default='data/Train400', type=str, help='path of train data')
#噪声水平  整型  默认25 
parser.add_argument('--sigma', default=25, type=int, help='noise level')
# epoch 整型  默认180
parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
#学习率  float 0.001  adam优化算法
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
#解析参数
#ArgumentParser 通过 parse_args() 方法解析参数。它将检查命令行,把每个参数转换为适当的类型然后调用相应的操作。
args = parser.parse_args()

#rgparse解析命令行参数来传递参数
batch_size = args.batch_size   #batch_size=128
cuda = torch.cuda.is_available()  #torch.cuda.is_available() cuda是否可用;
n_epoch = args.epoch   #n_epoch= 180
sigma = args.sigma     #sigma  = 25
#os.path.join():  将多个路径组合后返回args.model = DNCNN  str(sigma)=25
#组合之后的路径为models/DNCNN_sigma25
save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))

#判断路径是否存在 如果不存在 则新建路径
#os.mkdir()创建路径中的最后一级目录,即:只创建DNCNN_sigma25目录,而如果之前的目录不存在并且也需要创建的话,就会报错。
#os.makedirs()创建多层目录,即:models,DNCNN_sigma25如果都不存在的话,会自动创建,
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

#声明一个类,并继承自nn.Module
class DnCNN(nn.Module):
    #定义构造函数
    #构建网络最开始写一个class,然后def _init_(输入的量),然后super(DnCNN,self).__init__()这三句是程式化的存在,
    #初始化
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        ##初始化方法使用父类的方法即可,super这里指的就是nn.Module这个基类,第一个参数是自己创建的类名
        super(DnCNN, self).__init__()
        #定义自己的网络
        
        kernel_size = 3  #卷积核的大小  3*3
        padding = 1  ##padding表示的是在图片周围填充0的多少,padding=0表示不填充,padding=1四周都填充1维
        layers = [] 
        #四个参数 输入的通道  输出的通道  卷积核大小  padding
        #构建一个输入通道为channels,输出通道为64,卷积核大小为3*3,四周进行1个像素点的零填充的conv1层  #bias如果bias=True,添加偏置
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        ##增加网络的非线性——激活函数nn.ReLU(True)  在卷积层(或BN层)之后,池化层之前,添加激活函数
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            ##构建卷积层
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            #加快收敛速度一一批标准化层nn.BatchNorm2d()  输入通道为64的BN层 与卷积层输出通道数64对应
            #eps为保证数值稳定性(分母不能趋近或取0),给分母加上的值。默认为1e-4
            #momentum: 动态均值和动态方差所使用的动量。默认为0.1
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            #增加网络的非线性——激活函数nn.ReLU(True)  在卷积层(或BN层)之后,池化层之前,添加激活函数
            layers.append(nn.ReLU(inplace=True))
        #构建卷积层
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        #利用nn.Sequential()按顺序构建网络
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()  #调用初始化权重函数

    #定义自己的前向传播函数
    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            ## 使用isinstance来判断m属于什么类型【卷积操作】
            if isinstance(m, nn.Conv2d):
                #正交初始化(Orthogonal Initialization)主要用以解决深度网络下的梯度消失、梯度爆炸问题
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    #init.constant_常数初始化
                    init.constant_(m.bias, 0)
            ## 使用isinstance来判断m属于什么类型【批量归一化操作】
            elif isinstance(m, nn.BatchNorm2d):
                #init.constant_常数初始化
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

#定义损失函数类
class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)
    
    #MSELoss  计算input和target之差的平方
    #reduce(bool)- 返回值是否为标量,默认为True size_average(bool)- 当reduce=True时有效。为True时,返回的loss为平均值;为False时,返回的各样本的loss之和。
    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)


#看的不是特别懂,返回值要么最大的,要么0
def findLastCheckpoint(save_dir):
    #返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。
    file_list = glob.glob(os.path.join(save_dir, 'model_*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            #re.findall  的简单用法(返回string中所有与pattern相匹配的全部字串,返回形式为数组)
            result = re.findall(".*model_(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))  #append() 方法用于在列表末尾添加新的对象。
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch

#strftime()方法使用日期,时间或日期时间对象返回表示日期和时间的字符串
def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    
    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model matconvnet
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
        #加载模型
        model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
    #训练模型时会在前面加上
    model.train()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1  #返回的各样本的loss之和
    criterion = sum_squared_error()
    if cuda:
        model = model.cuda()
         # device_ids = [0]
         # model = nn.DataParallel(model, device_ids=device_ids).cuda()
         # criterion = criterion.cuda()
    ## Optimizer  采用Adam算法优化,模型参数  学习率
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    #milestones为一个数组,如 [50,70]. gamma为0.1 倍数。如果learning rate开始为0.01 ,则当epoch为50时变为0.001,epoch 为70 时变为0.0001。当last_epoch=-1,设定为初始lr
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    for epoch in range(initial_epoch, n_epoch):  #n——epoch = 180 

        scheduler.step(epoch)  # step to the learning rate in this epcoh
        xs = dg.datagenerator(data_dir=args.train_data)  #调用数据生成器函数
        xs = xs.astype('float32')/255.0   #对数据进行处理,位于【0 1】
        #torch.from_numpy将numpy.ndarray 转换为pytorch的 Tensor。  transpose多维数组转置
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # tensor of the clean patches, N X C X H X W
        #加噪声函数
        DDataset = DenoisingDataset(xs, sigma)
        #dataset:(数据类型 dataset)
        # num_workers:工作者数量,默认是0。使用多少个子进程来导入数据  
        #drop_last:丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
        #shuffle洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
        epoch_loss = 0  #初始化
        start_time = time.time()  #time.time() 返回当前时间的时间戳

        for n_count, batch_yx in enumerate(DLoader):  #enumerate() 函数用于将一个可遍历的数据对象
                optimizer.zero_grad()  #optimizer.zero_grad()意思是把梯度置零,也就是把loss关于weight的导数变成0.
                if cuda:
                    batch_x, batch_y = batch_yx[1].cuda(), batch_yx[0].cuda()
                loss = criterion(model(batch_y), batch_x)  #计算损失值
                epoch_loss += loss.item()  #对损失值求和
                loss.backward()  #反向传播
                optimizer.step()  #adam优化
                #每十张输出epoch  n_count  xs.size(0)//batch_size  loss.item()/batch_size)
                #不清楚xs.size(0)//batch_size是什么意思。1862定值
                if n_count % 10 == 0:
                    print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, loss.item()/batch_size))
        elapsed_time = time.time() - start_time  #当前时间-开始时间
        
       
        log('epcoh = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
        #numpy.savetxt(fname,X):第一个参数为文件名,第二个参数为需要存的数组(一维或者二维)第三个参数是保存的数据格式
        #hstack 和 vstack这两个函数分别用于在水平方向和竖直方向增加数据
        np.savetxt('train_result.txt', np.hstack((epoch+1, epoch_loss/n_count, elapsed_time)), fmt='%2.4f')
        # torch.save(model.state_dict(), os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))
        #保存模型
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))

二、查找资料链接

1、argparse --- 命令行选项、参数和子命令解析器

#coding:UTF-8
import argparse   # python的参数解析argparse模块
#创建一个解析器
#使用 argparse 的第一步是创建一个 ArgumentParser 对象
parser = argparse.ArgumentParser(description='PyTorch DnCNN')
#参加参数
#给一个 ArgumentParser 添加程序参数信息是通过调用 add_argument() 方法完成的。通常,这些调用指定 ArgumentParser 如何获取命令行字符串并将其转换为对象。
#模型  字符串   默认为DnCNN
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
#批量大小  整型   默认大小128 
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
#训练数据   字符串  默认 data/Train400  路径 
parser.add_argument('--train_data', default='data/Train400', type=str, help='path of train data')
#噪声水平  整型  默认25 
parser.add_argument('--sigma', default=25, type=int, help='noise level')
# epoch 整型  默认180
parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
#学习率  float 0.001  adam优化算法
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
#解析参数
#ArgumentParser 通过 parse_args() 方法解析参数。它将检查命令行,把每个参数转换为适当的类型然后调用相应的操作。
args = parser.parse_args()

#rgparse解析命令行参数来传递参数
batch_size = args.batch_size   #batch_size=128
n_epoch = args.epoch   #n_epoch= 180
sigma = args.sigma     #sigma  = 25

print('batch_size=%4d n_epoch=%4d sigma=%4d ' % (batch_size, n_epoch ,sigma))

对argparse模块进行测试,输出结果显示正确。模块主要对参数进行初始化,有点类似C语言中的define的作用。之后修改参数,只需修该argparsr中的参数即可。

2、re --- 正则表达式操作 — Python 3.7.4 文档

Python正则表达式指南 - AstralWind - 博客园

Python 标准库——os、glob模块 - 温柔一cai刀 - CSDN博客

python glob模块 - 火星大熊猫 - CSDN博客

Python标准库笔记(3) — datetime模块 - j_hao104 - 博客园

python datetime - 刘江的python教程

python: time模块、datetime模块 - yumu - CSDN博客

NumPy 教程 | 菜鸟教程

火炬 - PyTorch中文文档

torch.nn 神经网络工具 | AI初学者教程

pytorch loss function 总结 - 张小彬的专栏 - CSDN博客

pytorch教程之损失函数详解——多种定义损失函数的方法 - 神评网

torch.nn.init - PyTorch中文文档

PyTorch 中的数据类型 torch.utils.data.DataLoader - rogerfang的博客 - CSDN博客

torch.utils.data - PyTorch中文文档

torch.optim - PyTorch中文文档

pytorch中的学习率调整函数 - 慢行厚积 - 博客园

[pytorch中文文档] torch.optim - pytorch中文网

PyTorch学习之六个学习率调整策略 - mingo_敏 - CSDN博客

torch.cuda.is_available - daoer_sofu的专栏 - CSDN博客

python路径拼接os.path.join()函数完全教程 - 开贰锤 - CSDN博客

os.mkdir()和os.mkdirs()的区别和用法 - 算法小白 - CSDN博客

pytorch(二)--batch normalization的理解 - tanglinjie的CSDN博客 - CSDN博客

PyTorch参数初始化和Finetune - 知乎

PyTorch 实现中的一些常用技巧-PyTorch 中文网

Pytorch.nn.conv2d 过程验证(单,多通道卷积过程) - 知乎

PyTorch 学习笔记(六):PyTorch的十七个损失函数 - spectre - CSDN博客

PyTorch实战指南 - 知乎

【Python】正则表达式 re.findall 用法 - YZXnuaa的博客 - CSDN博客

Python List append()方法 | 菜鸟教程

Python-基础-时间日期处理小结

Python strftime() - datetime to string

Pytorch的net.train 和 net.eval的使用 - Never-Giveup的博客 - CSDN博客

torch.optim.lr_scheduler.MultiStepLR - qq_41872630的博客 - CSDN博客

Python numpy.transpose 详解 - November、Chopin - CSDN博客

Pytorch(五)入门:DataLoader 和 Dataset - 嘿芝麻的树洞 - CSDN博客

Python time time()方法 | 菜鸟教程

Python enumerate() 函数 | 菜鸟教程

torch代码解析 为什么要使用optimizer.zero_grad() - scut_salmon的博客 - CSDN博客

pytorch学习笔记(1)-optimizer.step()和scheduler.step() - 攻城狮的自我修养 - CSDN博客

Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解) - xiaoxifei的专栏 - CSDN博客

PyTorch 学习笔记(三):transforms的二十二个方法 - TensorSense的博客 - CSDN博客

python使用numpy读取、保存txt数据 - AManFromEarth的博客 - CSDN博客

Numpy 的一些基础操作必知必会-PyTorch 中文网

measure (measure) - Scikit image 中文开发手册 - 开发者手册 - 云+社区 - 腾讯云

pytorch学习(五)—图像的加载/读取方式 - 简书

matplotlib figure函数学习笔记 - 李啸林的专栏 - CSDN博客

matplotlib.pyplot.figure — Matplotlib 3.1.1 documentation

matplotlib模块数据可视化-图片处理 - sinat_36772813的博客 - CSDN博客

pytorch中图片显示问题 - lighting - CSDN博客

Matplotlib 教程 | 始终

imshow / matshow的插值 - Matplotlib 3.1.1文档

Matplotlib:给子图添加colorbar(颜色条或渐变色条) - 简书

PSNR和SSIM - 文森vincent - 博客园

Python os.listdir() 方法 | 菜鸟教程

Python中startswith和endswith的用法 - Fu4ng - CSDN博客

numpy.random.seed()的使用 - linzch3的博客 - CSDN博客

从np.random.normal()到正态分布的拟合 - Zhang's Wikipedia - CSDN博客

numpy.random.normal函数 - linyi_pk的博客 - CSDN博客

数据格式汇总及type, astype, dtype区别 - 机器学习-深度学习-图像处理-opencv-段子 - CSDN博客

我在读pyTorch文档(二) - aiqiu_gogogo的博客 - CSDN博客

(3条消息)np.reshape()和torch.view() - dspeia的博客 - CSDN博客

Torch张量的view方法有什么作用? - 纯净的天空

pytorch 正确的测试时间的代码 torch.cuda.synchronize() - u013548568的博客 - CSDN博客

PyTorch学习笔记(2)——变量类型(cpu/gpu) - g11d111的博客 - CSDN博客

os.path.splitext(“文件路径”) - 机器学习爱好者 - CSDN博客

numpy数组拼接:stack(),vstack(),hstack()函数使用总结 - Mao_Jonah的博客 - CSDN博客

NumPy 统计函数 | 菜鸟教程

OpenCV Python教程(1、图像的载入、显示和保存) - sunny2038的专栏 - CSDN博客

pytorch实现自由的数据读取-torch.utils.data的学习 - tsq292978891的博客 - CSDN博客

PyTorch—torch.utils.data.DataLoader 数据加载类 - wsp_1138886114的博客 - CSDN博客

Pytorch 04: Pytorch中数据加载---Dataset类和DataLoader类 - 一遍看不懂,我就再看一遍 - CSDN博客

torch.Tensor的4种乘法 - da_kao_la的博客 - CSDN博客

torch.mul() 和 torch.mm() 的区别 - Real_Brilliant的博客 - CSDN博客

PyTorch入门教程(1) - 知乎

pytorch torch张量 - pytorch中文网

PyTorch简明教程 - 李理的博客

x = x.view(x.size(0), -1) 的理解 - whut_ldz的博客 - CSDN博客

Data Augmentation--数据增强解决你有限的数据集 - chang_rj的博客 - CSDN博客

rot90--矩阵旋转 - qq_18343569的博客 - CSDN博客

矩阵的翻转与旋转()(另附代码) - 神评网

Python 中各种imread函数的区别与联系 - Mr. Chen - CSDN博客

opencv imread()方法第二个参数介绍 - qq_27278957的博客 - CSDN博客

位深度、色深的区别以及图片大小的计算 - cc65431362的专栏 - CSDN博客

numpy数据类型 - NumPy 中文文档

5 python numpy.expand_dims的用法 - hxshine的博客 - CSDN博客

numpy中的expand_dims函数 - qm5132的博客 - CSDN博客

(5条消息)numpy的delete删除数组整行和整列 - JamesShawn - CSDN博客

什么是判别模型(Discriminative Model)和生成模型(Generative Model) - zhaoyu106的博客 - CSDN博客

图像去噪算法简介 - InfantSorrow - 博客园

自然图像先验与图像复原 - zbwgycm的博客 - CSDN博客

机器学习概念篇:一文详解凸函数和凸优化,干货满满 - feilong_csdn的博客 - CSDN博客

【图像缩放】双立方(三次)卷积插值 - 程序生涯 - SegmentFault 思否

receptive field,即感受野 - coder - CSDN博客

图像质量评价标准学习笔记(1)-均方误差、峰值信噪比、结构相似性理论、多尺度结构相似性 - weixin_42769131的博客 - CSDN博客

Batch Normalization + Internal Covariate Shift(论文理解) - jason19966的博客 - CSDN博客

torch.nn - PyTorch中文文档

OpenCV图像的基本操作 · OpenCV-Python中文教程 · 看云

[Python] glob 模块(查找文件路径) - 简书

这 5 种计算机视觉技术,刷新你的世界观 | Laravel China 社区

卷积神经网络CNN:Tensorflow实现(以及对卷积特征的可视化) - TinyMind -专注人工智能的技术社区

train_data = train_data.transpose((0,3,1,2))ValueError:轴与数组不匹配·问题#4·ZFTurbo / KAGGLE_DISTRACTED_DRIVER

机器学习 | 王成飞博客

超越图像分类:更多应用深度学习的方法 - 知乎

分类: PyTorch | 从零开始的BLOG

发布了443 篇原创文章 · 获赞 656 · 访问量 60万+

猜你喜欢

转载自blog.csdn.net/LiuJiuXiaoShiTou/article/details/100170728