[Deep Learning] [pytorch] Truly pruning zeroed convolution kernels from convolutional layers

The blogger recently needed to deploy deep learning models, which meant the models had to be compressed. These notes draw on many posts by experienced authors, and every method has been tested personally. They are shared here for everyone to study and discuss together.


Preface

Deep learning pruning is a technique that removes redundant connections (weights) or nodes (neurons) from a neural network to make the model sparse, thereby reducing model size and computation and improving inference efficiency.
Pruning has the following benefits: 1. model compression and storage savings; 2. reduced use of computing resources; 3. faster inference; 4. helping to prevent overfitting.
"Fake pruning" refers to pruning approaches that do not actually delete weights or nodes; instead, they set them to zero or otherwise disable them to simulate the effect of pruning. Many excellent papers adopt this "fake pruning" strategy. Although it may improve the model's inference speed to some extent, fake pruning does not actually reduce the size of the model, because the zeroed weights are still stored in tensors of the original shape. Through a small example, the blogger explains a simple, easy-to-understand way to truly prune the zeroed kernels of a "fake-pruned" convolutional layer.

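To see the difference in practice, here is a minimal sketch that uses PyTorch's built-in torch.nn.utils.prune module rather than the blogger's code. prune.ln_structured zeroes whole output kernels by their L2 norm (a form of fake pruning), and prune.remove makes the zeros permanent; the weight tensor nevertheless keeps its shape and the parameter count does not change. The 40% pruning ratio is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 5, 3)  # 5 kernels of shape 3x3x3, plus 5 biases
params_before = sum(p.numel() for p in conv.parameters())

# "Fake" structured pruning: zero the 40% of output kernels (dim=0) with the smallest L2 norm
prune.ln_structured(conv, name="weight", amount=0.4, n=2, dim=0)
# Fold the pruning mask into the weight so it becomes a plain Parameter again
prune.remove(conv, "weight")

# Count kernels whose entries are all zero, and the parameters still stored
kernel_norms = conv.weight.detach().flatten(1).norm(dim=1)
num_zero_kernels = int((kernel_norms == 0).sum())
params_after = sum(p.numel() for p in conv.parameters())

print("zeroed kernels:", num_zero_kernels)
print("weight shape:", tuple(conv.weight.shape))
print("parameters before/after:", params_before, params_after)
```

With 5 kernels and amount=0.4, two kernels become all zero, yet the layer still stores a 5×3×3×3 weight and 140 parameters.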

Convolution layer pruning

You can first copy the complete code at the end into your own .py file, then follow the blogger's steps to learn how to truly prune the zeroed convolution kernels:

  1. Initialize a convolutional layer and inspect its weights
    import torch.nn as nn
    # Example: a convolutional layer with 3 input channels and 5 output channels
    conv = nn.Conv2d(3, 5, 3)
    print("Original conv layer weights:")
    print(conv.weight.data)
    print(conv.weight.size())
    print("Original conv layer bias:")
    print(conv.bias.data)
    print(conv.bias.size())
    
  2. Use a random number to zero out some of the convolution kernels, simulating "fake pruning".
    # Code inside the remove_zero_kernels function
    weight = conv_layer.weight.data
    # Number of convolution kernels (output channels)
    num_kernels = weight.size(0)
    # Mask used to zero out some of the kernels
    pruned = torch.ones(num_kernels, 1, 1, 1, device=weight.device, dtype=weight.dtype)
    # Randomly choose how many kernels to zero (1 to num_kernels - 1, so num_kernels must be >= 2);
    # the first random_int kernels are the ones zeroed
    random_int = random.randint(1, num_kernels - 1)
    pruned[:random_int] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    
  3. Keep the weights and biases of the kernels that were not pruned
    # Compute the L2 norm of each kernel to check whether all of its elements are 0
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms == 0).flatten()
    print(zero_kernel_indices)
    # Remove the kernels whose L2 norm is zero, using a boolean mask over output channels
    keep = norms != 0
    new_weight = weight[keep]
    new_bias = bias[keep]
    
  4. Build a new convolutional layer to replace the original one, completing the real pruning of the zeroed kernels.
    # Build the new convolutional layer
    if zero_kernel_indices.numel() > 0:
        # Input channels (this assumes groups == 1)
        in_channels = conv_layer.in_channels
        # Output channels: only the kernels that were kept
        out_channels = new_weight.size(0)
        # Kernel size, read from the layer so non-square kernels also work
        kernel_size = conv_layer.kernel_size
        # Stride, padding, dilation and groups are copied from the original layer
        stride = conv_layer.stride
        padding = conv_layer.padding
        dilation = conv_layer.dilation
        groups = conv_layer.groups
        new_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride=stride, padding=padding,
                                   dilation=dilation, groups=groups)
        new_conv_layer.weight.data = new_weight
        new_conv_layer.bias.data = new_bias
    else:
        new_conv_layer = conv_layer
    
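Steps 3 and 4 can also be written as one self-contained function. The following is a sketch, not the blogger's code: the helper name prune_zero_kernels is hypothetical. It builds a boolean mask over output kernels, copies the surviving weights and biases into a new nn.Conv2d, reads kernel_size and padding_mode from the layer (so non-square kernels work), handles layers without a bias, and assumes groups == 1.

```python
import torch
import torch.nn as nn

def prune_zero_kernels(conv: nn.Conv2d) -> nn.Conv2d:
    """Hypothetical helper: return a new Conv2d without the all-zero output kernels."""
    # Assumption: ordinary (non-grouped) convolution; grouped convs need extra care
    assert conv.groups == 1, "this sketch only handles groups == 1"
    weight = conv.weight.detach()
    # A kernel is "fake pruned" if every entry is zero, i.e. its L2 norm is zero
    keep = weight.flatten(1).norm(dim=1) != 0
    if bool(keep.all()):
        return conv  # nothing to remove
    new_conv = nn.Conv2d(
        conv.in_channels,
        int(keep.sum()),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    ).to(device=weight.device, dtype=weight.dtype)
    # Copy the surviving kernels (and their biases) into the smaller layer
    with torch.no_grad():
        new_conv.weight.copy_(weight[keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias.detach()[keep])
    return new_conv

# Usage: zero kernels 1 and 3 of a 5-kernel layer, then remove them for real
conv = nn.Conv2d(3, 5, kernel_size=(3, 5))
with torch.no_grad():
    conv.weight[[1, 3]] = 0
pruned = prune_zero_kernels(conv)
print(tuple(conv.weight.shape), "->", tuple(pruned.weight.shape))
```

The boolean mask replaces the per-index membership test of the list comprehension, and it always yields a 1-D selection even when only one kernel is zero.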

Complete code

import torch
import torch.nn as nn
import random

def remove_zero_kernels(conv_layer):
    # Convolution kernel weights
    weight = conv_layer.weight.data
    # Number of convolution kernels (output channels)
    num_kernels = weight.size(0)
    # Mask used to zero out some of the kernels ("fake pruning")
    pruned = torch.ones(num_kernels, 1, 1, 1, device=weight.device, dtype=weight.dtype)
    # Randomly choose how many kernels to zero (1 to num_kernels - 1, so num_kernels must be >= 2);
    # the first random_int kernels are the ones zeroed
    random_int = random.randint(1, num_kernels - 1)
    pruned[:random_int] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    # Compute the L2 norm of each kernel to check whether all of its elements are 0
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms == 0).flatten()
    print(zero_kernel_indices)
    # Remove the kernels whose L2 norm is zero, using a boolean mask over output channels
    keep = norms != 0
    new_weight = weight[keep]
    new_bias = bias[keep]
    # Build the new convolutional layer
    if zero_kernel_indices.numel() > 0:
        # Input channels (this assumes groups == 1)
        in_channels = conv_layer.in_channels
        # Output channels: only the kernels that were kept
        out_channels = new_weight.size(0)
        # Kernel size, read from the layer so non-square kernels also work
        kernel_size = conv_layer.kernel_size
        # Stride, padding, dilation and groups are copied from the original layer
        stride = conv_layer.stride
        padding = conv_layer.padding
        dilation = conv_layer.dilation
        groups = conv_layer.groups
        new_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride=stride, padding=padding,
                                   dilation=dilation, groups=groups)
        new_conv_layer.weight.data = new_weight
        new_conv_layer.bias.data = new_bias
    else:
        new_conv_layer = conv_layer

    return new_conv_layer

# Example: a convolutional layer with 3 input channels and 5 output channels
conv = nn.Conv2d(3, 5, 3)
# print("Original conv layer weights:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("Original conv layer bias:")
# print(conv.bias.data)
# print(conv.bias.size())

# Remove the zeroed kernels
new_conv = remove_zero_kernels(conv)
# print("Pruned conv layer weights:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("Pruned conv layer bias:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())
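A quick way to confirm that the real pruning preserves the layer's behavior is to compare outputs. The sketch below is an addition, not part of the blogger's code, and the 1×3×8×8 input size is arbitrary: it zeroes the first two kernels, builds the smaller layer from the surviving weights as in steps 3 and 4, and checks that its output matches the kept channels of the original. It also shows one subtlety: a zeroed kernel whose bias is left untouched still outputs a constant map equal to its bias rather than zeros, so removing that channel changes what downstream layers receive unless that bias is also zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 5, 3)
with torch.no_grad():
    conv.weight[:2] = 0  # "fake pruning": zero the first two kernels, biases untouched

# Real pruning: copy the surviving kernels and biases into a smaller layer
keep = conv.weight.detach().flatten(1).norm(dim=1) != 0
small = nn.Conv2d(3, int(keep.sum()), 3)
with torch.no_grad():
    small.weight.copy_(conv.weight[keep])
    small.bias.copy_(conv.bias[keep])

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    full_out = conv(x)
    small_out = small(x)

# The kept channels are reproduced (up to floating-point tolerance)
print(torch.allclose(full_out[:, keep], small_out, atol=1e-5))
# A zeroed kernel still emits a constant map equal to its bias, not zeros
print(torch.allclose(full_out[:, 0], conv.bias[0].detach().expand_as(full_out[:, 0])))
```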

Summary

The blogger's approach is to initialize a new convolutional layer with the retained (unpruned) weights of the original layer, so that the zeroed kernels left by fake pruning are truly removed. If any readers have studied this topic, feel free to share other methods with the blogger so we can make progress together.
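A quick parameter count shows what the real pruning saves for the example layer. This worked count is an addition, assuming groups = 1 and a bias term: each output channel of a Conv2d layer owns one C_in × k_h × k_w kernel plus one bias, so the parameter count P grows linearly with C_out (as does the number of multiply-accumulates per output position, C_out · C_in · k_h · k_w).

```latex
\begin{aligned}
P(C_{out}) &= C_{out}\,(C_{in}\,k_h\,k_w + 1) \\
P(5) &= 5\,(3 \cdot 3 \cdot 3 + 1) = 5 \cdot 28 = 140 \\
P(5 - r) &= 28\,(5 - r), \qquad \text{e.g. } r = 2 \;\Rightarrow\; P(3) = 84
\end{aligned}
```

Zeroing r kernels still leaves 140 stored parameters; only rebuilding the layer with 5 − r kernels actually saves 28r of them.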


Source: blog.csdn.net/yangyu0515/article/details/134185435