Custom code for Network Slim pruning

A recent project required deploying our detection algorithm. After deployment, inference speed was not fast enough, so the model needed to be pruned.

The pruning algorithm I refer to is a classic one from ICCV 2017: "Learning Efficient Convolutional Networks through Network Slimming". The principle is very simple; the only prerequisite is understanding how Batch Normalization works. I have briefly analyzed the principle of Batch Normalization, along with vanishing and exploding gradients, in an earlier article, which you can consult if needed.

Since my detection algorithm is built on the mmdet framework, I first tried Microsoft's open-source pruning tool nni, but without success. In the end I decided to write the pruning code myself. Below is my pruning approach and its concrete implementation.

1.1 Pruning ideas

According to the principle from the paper referenced above, we prune at the BN layer. Take resnet50 as an example: when building the network, a BN layer is usually connected after a Conv layer, and (apart from the activation function) the BN layer feeds the next Conv layer. For convenience, denote the convolution before the BN layer as Conv1 and the one after it as Conv2. Their relationship is: the number of output channels of Conv1 matches the parameter dimension of the BN layer, which must in turn match the number of input channels of Conv2. Taking one layer of resnet50 as an example (figure: Conv1 → BN → Conv2 channel dimensions): Conv1 has 256 input channels, a [1,1] convolution kernel, and 64 output channels; the output of Conv1 is the input of the BN layer, whose output keeps the same channel count of 64; that output is then the input of Conv2, which accepts 64 input channels and also outputs 64 channels.
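The Conv1 → BN → Conv2 channel relationship above can be sketched with a minimal PyTorch block (the layer sizes mirror the resnet50 example; the input resolution of 56×56 is an illustrative assumption):

```python
import torch
import torch.nn as nn

# Conv1 maps 256 channels to 64, BN normalizes those 64 channels,
# and Conv2 must therefore accept 64 input channels.
conv1 = nn.Conv2d(256, 64, kernel_size=1, bias=False)
bn = nn.BatchNorm2d(64)
conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)

x = torch.randn(1, 256, 56, 56)
out = conv2(bn(conv1(x)))
print(bn.weight.shape)  # torch.Size([64]) — one gamma per channel
print(out.shape)        # torch.Size([1, 64, 56, 56])
```

Pruning the BN layer means shrinking that shared 64-channel dimension in all three modules at once.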

Next, I will introduce the BN pruning method, which is also very simple. For each input channel, the BN layer learns two parameters, $\gamma$ and $\beta$, corresponding to the following formula:

$$y_{1} \leftarrow \gamma_{1}\hat{x}_{1}+\beta_{1} \equiv \mathrm{BN}_{\gamma_{1},\beta_{1}}(x_{1})$$

The parameter $\gamma$ can therefore be seen as a weight measuring the importance of each channel. We set a threshold: a channel is deleted if its $\gamma$ falls below the threshold and kept if it is above, as shown in the diagram (figure: channels with small $\gamma$ are pruned away).
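One way to pick the threshold, as in the paper, is globally: gather every BN $\gamma$ in the model and take a percentile. The tiny model, the randomized gammas, and the 50% prune ratio below are illustrative assumptions (in a real model the gammas come from sparsity-regularized training):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32),
)
# BN gammas default to 1.0; randomize them purely for illustration
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.uniform_(0, 1)

# collect every gamma in the model into one flat tensor
gammas = torch.cat([m.weight.data.abs().flatten()
                    for m in model.modules()
                    if isinstance(m, nn.BatchNorm2d)])
prune_ratio = 0.5                       # prune the smallest 50% of channels
thresh = torch.quantile(gammas, prune_ratio)
print(gammas.numel(), thresh.item())
```

A fixed hand-tuned threshold (as used later in this post) works too; the percentile form just gives direct control over the overall prune ratio.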

1.2 Code Implementation Ideas

After clarifying the principle, the next step is writing code to realize it. To prune a BN layer, we must prune the two Conv layers directly connected to it at the same time.

First, load the entire model. The method I use here is simply the torch.load() function:

# save the entire network
torch.save(model, PATH)
# load the entire model
model = torch.load(PATH)

The model loaded this way prints like this:
(figure: partial printout of the resnet50 structure)
You can see the structure of the model (I only took a part of resnet50). Next, print out its parameters:
(figure: partial printout of parameter shapes)
The above are part of the model's parameter dimensions. The BN layer parameters are 64-dimensional, corresponding to the number of output channels of the preceding Conv layer; the BN layer carries four per-channel tensors — weight, bias, running mean, and running var — and we need to prune all of them at the same time.
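Those four tensors can be listed directly from a BatchNorm2d module; for a BN following a 64-channel Conv, each has shape (64,):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(64)
# the four per-channel tensors that must all be pruned together
for name in ("weight", "bias", "running_mean", "running_var"):
    print(name, tuple(getattr(bn, name).shape))  # each prints shape (64,)
```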

As mentioned in the principle section, we take the $\gamma$ parameter as the weight measuring channel importance; it corresponds to the `weight` tensor of the BN module. So when pruning a BN layer, we first set a threshold, then compare each value of the weight against it to obtain the indices of the channels above the threshold:

def find_indice(module, thresh):  # module is a BN layer
    gamma = module.weight.data
    mask = gamma > thresh
    indices = torch.nonzero(mask).view(-1)
    return indices
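A quick usage sketch of this helper (redefined here so the snippet is self-contained; the gamma values are made up for illustration):

```python
import torch
import torch.nn as nn

def find_indice(module, thresh):  # module is a BN layer
    gamma = module.weight.data
    mask = gamma > thresh
    return torch.nonzero(mask).view(-1)

bn = nn.BatchNorm2d(8)
bn.weight.data = torch.tensor([0.5, 0.05, 0.3, 0.01, 0.9, 0.12, 0.02, 0.4])
keep = find_indice(bn, thresh=0.17)
print(keep)  # tensor([0, 2, 4, 7]) — the channels whose gamma exceeds 0.17
```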

Then we use these indices to prune the four parameter tensors of the BN layer. Besides the parameters, special attention must be paid to pruning the structure, i.e. updating the channel count recorded on the BN module to the pruned dimension:

# prune the parameters
m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"]]  # gamma
m.bias.data = m.bias.data[bn_dict["backbone.layer1.0.bn1"]]      # beta
m.running_mean.data = m.running_mean.data[bn_dict["backbone.layer1.0.bn1"]]
m.running_var.data = m.running_var.data[bn_dict["backbone.layer1.0.bn1"]]
# prune the structure
m.num_features = bn_dict["backbone.layer1.0.bn1"].size()[0]

In addition, we must trim the output channels of the Conv before the BN layer and the input channels of the Conv after it. The overall code:

# first collect, for every BN layer, the indices of the channels to keep
bn_dict = dict()
for name, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        indice = find_indice(m, thresh=0.17)
        bn_dict[name] = indice

# perform the pruning (again iterating over named_modules)
for name, m in model.named_modules():
    if name == "backbone.layer1.0.conv1":
        m.weight.data = m.weight.data[:, bn_dict["backbone.bn1"], :, :]           # input channels
        m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"], :, :, :]  # output channels
        m.in_channels = bn_dict["backbone.bn1"].size()[0]
        m.out_channels = bn_dict["backbone.layer1.0.bn1"].size()[0]
    if name == "backbone.layer1.0.bn1":
        m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"]]  # gamma
        m.bias.data = m.bias.data[bn_dict["backbone.layer1.0.bn1"]]      # beta
        m.running_mean.data = m.running_mean.data[bn_dict["backbone.layer1.0.bn1"]]
        m.running_var.data = m.running_var.data[bn_dict["backbone.layer1.0.bn1"]]
        m.num_features = bn_dict["backbone.layer1.0.bn1"].size()[0]
    if name == "backbone.layer1.0.conv2":
        m.weight.data = m.weight.data[:, bn_dict["backbone.layer1.0.bn1"], :, :]
        m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn2"], :, :, :]
        m.in_channels = bn_dict["backbone.layer1.0.bn1"].size()[0]
        m.out_channels = bn_dict["backbone.layer1.0.bn2"].size()[0]
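As a standalone sanity check (a sketch with made-up layer sizes, not the actual mmdet model), the same three-step prune can be applied to a toy Conv-BN-Conv chain and verified with a forward pass:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv2d(16, 32, 1, bias=False)
bn1 = nn.BatchNorm2d(32)
conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)

# randomized gammas stand in for sparsity-trained ones; keep those > 0.5
bn1.weight.data.uniform_(0, 1)
keep = torch.nonzero(bn1.weight.data > 0.5).view(-1)

# prune conv1 outputs, all four BN tensors, and conv2 inputs consistently
conv1.weight.data = conv1.weight.data[keep, :, :, :]
conv1.out_channels = keep.numel()
bn1.weight.data = bn1.weight.data[keep]
bn1.bias.data = bn1.bias.data[keep]
bn1.running_mean.data = bn1.running_mean.data[keep]
bn1.running_var.data = bn1.running_var.data[keep]
bn1.num_features = keep.numel()
conv2.weight.data = conv2.weight.data[:, keep, :, :]
conv2.in_channels = keep.numel()

# the chain still runs end to end with the reduced middle dimension
out = conv2(bn1(conv1(torch.randn(1, 16, 8, 8))))
print(out.shape)  # torch.Size([1, 32, 8, 8]) — conv2's outputs are untouched
```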

After cutting, save the model; when printed, it is the model with both weights and structure trimmed:

torch.save(model, "path_to_save_pruned_model")

Printing it shows the reduced channel counts (figure: printout of the pruned model).
The pruned model can also be run directly.

The backbone of my own detection model is resnet50, and I only trimmed the backbone. After trimming, mAP dropped a little (how much depends on the chosen threshold; aggressive trimming drops it much more), but the result after fine-tuning was better than before trimming. My tentative interpretation of this phenomenon is that the BN layer's $\gamma$ parameter acts as a kind of attention mechanism. If someone can give a better explanation, please correct me.

1.3 Precautions

There were some pitfalls along the way; I record them here:

Resnet50 has residual connections, which require care. At the start of each reslayer there is also a downsample branch, consisting of a Conv and a BN. If the channels feeding into or out of the reslayer are pruned, remember to adjust the downsample layer to match!
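The reason is that the main branch and the downsample branch are added together, so both must keep the same output channels. A hedged sketch with made-up sizes (8 channels, a hypothetical `keep` index set):

```python
import torch
import torch.nn as nn

keep = torch.tensor([0, 2, 3, 5, 6])  # hypothetical surviving channels of 8

main_conv = nn.Conv2d(4, 8, 3, padding=1, bias=False)
main_bn = nn.BatchNorm2d(8)
down_conv = nn.Conv2d(4, 8, 1, bias=False)  # the downsample branch
down_bn = nn.BatchNorm2d(8)

# both branches must be pruned to the SAME indices, or the residual
# addition below no longer lines up channel-wise
for conv, bn in [(main_conv, main_bn), (down_conv, down_bn)]:
    conv.weight.data = conv.weight.data[keep, :, :, :]
    conv.out_channels = keep.numel()
    bn.weight.data = bn.weight.data[keep]
    bn.bias.data = bn.bias.data[keep]
    bn.running_mean.data = bn.running_mean.data[keep]
    bn.running_var.data = bn.running_var.data[keep]
    bn.num_features = keep.numel()

x = torch.randn(1, 4, 8, 8)
out = main_bn(main_conv(x)) + down_bn(down_conv(x))  # addition still matches
print(out.shape)  # torch.Size([1, 5, 8, 8])
```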

Attached: a resnet50 structure diagram (figure omitted).

Origin blog.csdn.net/weixin_45453121/article/details/130891939