DFMN Code Interpretation

Table of contents

0. Environment configuration

1. Run the program

2. The approach to reading the code

1) model.py

!! About Inheritance

!! About network structure organization

!! About forward

2) Data preprocessing

3) train.py


0. Environment configuration

This is straightforward: if the program reports that a package xxx is missing, install it with pip install xxx.

1. Run the program

Start from DFNI/train.py; the program should run normally without reporting errors.

2. The approach to reading the code

Because the main contribution of DFMN (renamed DFNI) is the design of the network structure, the order in which I read the code is:

1) model.py 2) train.py 3) the data preprocessing code files

1) model.py

!! About Inheritance

Because it is a self-designed network structure, it inherits from nn.Module; the initialization code is shown below.

You can see that the defined DFNI network inherits from nn.Module. There are two places to pay attention to:

  1. class DFNI(nn.Module): the subclass declaration takes the name of the parent class as its parameter
  2. super(DFNI, self).__init__(): when the subclass is initialized, the parent class is initialized first so that all of its methods and properties are inherited

Attributes and methods that differ from the parent class are then defined on the subclass to override it, for example self.firstPart and self.midPart1.
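
As a side note, here is a minimal illustration (not from this repo) of why super().__init__() cannot be skipped: nn.Module.__setattr__ is what registers submodules, and it only works after the parent class has been initialized.

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # super(Broken, self).__init__() is deliberately missing here
        self.layer = nn.Linear(4, 4)

# Broken()  # raises AttributeError: cannot assign module before Module.__init__() call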

 
import torch
import torch.nn as nn


class DFNI(nn.Module):
    def __init__(self, upscale_factor):
        super(DFNI, self).__init__()
        self.firstPart = nn.Sequential(
            # ... layers elided here; firstPart is shown in full below
        )
        self.midPart1 = nn.Sequential(
            # ... layers elided
        )

        self.midPart2 = nn.Sequential(
            # ... layers elided
        )

        # for p in self.parameters():
        #     p.requires_grad = False

        self.finalPart = nn.Sequential(
            # ... layers elided
        )
        self.con1 = nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2)
        self.con2 = nn.Conv2d(8, 16, kernel_size=5, padding=5 // 2)
        self.con3 = nn.Conv2d(16, 32, kernel_size=3, padding=3 // 2)
        self.con4 = nn.Conv2d(32, 64, kernel_size=1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x, xx):
        # ... the full forward is shown in the "About forward" section below
        pass
 

 !! About network structure organization

Packaging convolutions and ReLUs: the most common components of the network are convolutions and ReLUs. A block made up of several convolutions and nonlinear activation functions is usually wrapped into a module with nn.Sequential() and given its own name, which makes it convenient to call later in forward.

The parameter settings of nn.Conv2d: in nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2), 1 is the number of input channels, 8 is the number of output channels, and 7 is the kernel size. The stride defaults to 1 and padding is done with zeros; here the padding is 7 // 2, i.e. half the kernel size, which keeps the spatial size of the feature map unchanged after the convolution (see the small check after the parameter list below).

nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

in_channels: number of channels of the input data, e.g. 3 for an RGB image;
out_channels: number of channels of the output data, chosen according to the model;
kernel_size: size of the convolution kernel, an int or a tuple; kernel_size=2 means a (2, 2) kernel, kernel_size=(2, 3) means a (2, 3) kernel, i.e. a non-square convolution;
stride: step size, default 1; like kernel_size it can be an int or a tuple: stride=2 means a step of 2 in both directions, stride=(2, 3) means a step of 2 along the height and 3 along the width;
padding: zero padding, default 0
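
A small check (not from the original code) of the padding = kernel_size // 2 rule: with stride 1 and an odd kernel size, the output keeps the same height and width as the input.

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2)  # padding = 3
x = torch.randn(1, 1, 175, 63)                         # (N, C, H, W); sizes borrowed from the shape-check example later in this post
print(conv(x).shape)                                   # torch.Size([1, 8, 175, 63]) -- H and W unchanged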

self.firstPart = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7, padding=7 // 2),
            nn.LeakyReLU(inplace=True),   # why LeakyReLU? see the note below
            nn.Conv2d(8, 16, kernel_size=5, padding=5 // 2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=3 // 2),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=1),
            nn.LeakyReLU(inplace=True)
        )

nn.LeakyReLU(inplace=True)

ReLU sets all negative values to zero; Leaky ReLU, by contrast, gives negative values a small non-zero slope.
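
A tiny illustration (not from the repo) of the difference, using nn.LeakyReLU's default negative_slope of 0.01:

import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
print(nn.ReLU()(x))        # tensor([0., 0., 0., 1.])
print(nn.LeakyReLU()(x))   # tensor([-0.0200, -0.0050, 0.0000, 1.0000])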

!! About forward

It is enough to chain the defined modules together according to the data flow: the output of one module is the input of the next.

    def forward(self, x, xx):    # where does xx come from? (see the data preprocessing section)
        res = x
        # x = self.firstPart(x)
        x1 = self.lrelu(self.con1(x))
        x2 = self.lrelu(self.con2(x1))
        x3 = self.lrelu(self.con3(x2))
        x4 = self.lrelu(self.con4(x3))
        x = torch.cat((x1, x2, x3, x4), dim=1)  # dim = 0, 1, 2 would grow the rows, columns, or depth respectively -- why is dim 1 here? (answered below)
        b1 = self.midPart1(x)
        b2 = self.midPart2(x)
        x = torch.cat((b1, b2), dim=1)
        x = torch.add(res, x)  # what is the difference between add and cat, and how are the dimensions set in the network design? (answered below)
        x = self.finalPart(x)
        x = torch.add(x, xx)
        return x

During forward propagation there are cascaded modules, parallel branches, and residual connections, so it can be confusing how to set the number of channels in each layer. The following is the answer given by the author of the code.

Why is dim 1 in x = torch.cat((x1, x2, x3, x4), dim=1)? The model input is a four-dimensional tensor: the first dimension is the batch size (here 1), the second is the number of channels, the third is the rows, and the fourth is the columns. dim=1 therefore concatenates along the channel dimension.
In x = torch.add(res, x), the single channel of res is broadcast (copied) to match the number of channels of x, e.g. 1 -> 65, and the tensors are then added element-wise.
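
A minimal sketch (with made-up shapes, not the actual DFNI channel counts) of the difference between cat along dim=1 and a broadcast add:

import torch

a = torch.randn(1, 8, 4, 4)
b = torch.randn(1, 16, 4, 4)
print(torch.cat((a, b), dim=1).shape)   # torch.Size([1, 24, 4, 4]) -- channel counts are summed

res = torch.randn(1, 1, 4, 4)           # a single-channel tensor, like res in forward
x = torch.randn(1, 65, 4, 4)
print(torch.add(res, x).shape)          # torch.Size([1, 65, 4, 4]) -- res is broadcast across the 65 channels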

In the early stage of network design, how do you track the output shape of each layer and make sure each layer's parameters are correct, especially where there is concatenation? After the model's forward is written, call the model with random tensors as input to test it; if the sizes do not work out, use print(xx.shape) inside the model to inspect them.

# outside the model definition
model = DFNI(4)
input1 = torch.randn(1, 1, 175, 63)
input2 = torch.randn(1, 1, 700, 252)
model.load_state_dict(torch.load('DFNI_4.pt', map_location='cpu'))
out = model(input1, input2)

# inside the model (e.g. at the end of forward)
print(x4.shape)

2) Data preprocessing

The data in the paper comes from an open-source dataset. The downsampling is regular (uniform) downsampling; irregular downsampling would introduce noise.

The model's data consists of: input (the downsampled data plus the interpolation result produced by a traditional method <run inputdataGet>), and output (the original data).

We need to prepare a training set and a test set. In PyTorch, two classes are used to read a dataset: Dataset and DataLoader. Dataset is responsible for reading the data, i.e. each sample and its corresponding label; DataLoader wraps the data read by Dataset and feeds it to the neural network in batches.

For a custom dataset, the key is to wrap the data as a Dataset so that DataLoader can use it.

In this example, data loaded from .npy files is converted to a Dataset.

import numpy as np
from torch.utils.data import Dataset, DataLoader


# DatasetFromFolder inherits from Dataset; its purpose is to wrap the custom data as a Dataset
class DatasetFromFolder(Dataset):
    def __init__(self, input1, input2, target):
        super(DatasetFromFolder, self).__init__()
        self.input1 = input1
        self.input2 = input2
        self.target = target

    def __getitem__(self, index):
        return self.input1[index], self.input2[index], self.target[index]

    def __len__(self):
        return len(self.target)


# input1, input2, target all come from .npy files (loaded by the project's dataGet_2);
# this is an example of converting npy data to a Dataset
input1, input2, target = dataGet_2(num)
input1 = input1.astype(np.float32)  # cast the arrays to float32
input2 = input2.astype(np.float32)
target = target.astype(np.float32)


trainSet = DatasetFromFolder(input1, input2, target)
trainDataLoader = DataLoader(dataset=trainSet)
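
As a rough usage sketch (batch_size and shuffle here are illustrative choices, not necessarily what the repo uses), the DataLoader is then iterated in the training loop and yields the three tensors returned by __getitem__:

loader = DataLoader(dataset=trainSet, batch_size=4, shuffle=True)
for batch_input1, batch_input2, batch_target in loader:
    print(batch_input1.shape, batch_input2.shape, batch_target.shape)
    break   # just inspect the first batch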

3) train.py

Because the amount of data is small, the experiment in this paper did not use a validation set (??? so I am not sure how credible the experimental results are); perhaps it is the final residual connection (bringing the traditional interpolation result into training) that ensures the experiment achieves a better effect.
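
For reference, a hedged sketch of what the training loop in train.py typically looks like for this kind of model; the loss function, optimizer, learning rate, and epoch count are assumptions, not necessarily what the author used:

import torch.nn as nn
import torch.optim as optim

model = DFNI(4)
criterion = nn.MSELoss()                              # assumed loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer and learning rate

for epoch in range(100):
    for input1, input2, target in trainDataLoader:    # assumes the arrays already carry a channel dimension
        optimizer.zero_grad()
        out = model(input1, input2)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()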

For a standard training / validation / testing template (very well written), see: Pytorch model training and model verification_MoxiMoses' blog-CSDN blog_pytorch training model

Origin blog.csdn.net/u014655960/article/details/128537669