PyTorch Model Pruning

pruning classification

So-called model pruning is a model compression technique that removes "unnecessary" weights or biases (weights/biases) from a neural network. Exactly which parameters are "unnecessary" is still an open research question.

unstructured pruning

Unstructured pruning refers to pruning individual elements of a parameter tensor, such as a single weight in a fully connected layer, a single element of a convolution kernel in a convolutional layer, or individual scaling floats in a custom layer. The key point is that the set of pruned weights is irregular and has no particular structure, hence the name unstructured pruning.

structured pruning

In contrast to unstructured pruning, structured pruning removes entire blocks of parameters at once, for example dropping entire rows or columns of a weight matrix, or dropping entire filters in convolutional layers.

Local vs. Global Pruning

Pruning can be done per layer (local) or on multiple/all layers (global).

Pruning in PyTorch

The weight pruning methods currently supported by the PyTorch framework are:

  • Random: simply prune randomly selected parameters.

  • Magnitude: prune the parameters with the smallest magnitude (e.g., by their L2 norm).

Both methods are simple to implement, cheap to compute, and can be applied without any data.
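As a quick illustration, these two approaches correspond to prune.random_unstructured and prune.l1_unstructured (a minimal sketch; the nn.Linear layer size and the 30% ratio are arbitrary choices for this example):

import torch
import torch.nn.utils.prune as prune

fc_random = torch.nn.Linear(8, 4)
fc_magnitude = torch.nn.Linear(8, 4)

# Random: zero out ~30% of the weights, chosen at random.
prune.random_unstructured(fc_random, name="weight", amount=0.3)

# Magnitude: zero out the ~30% of weights with the smallest absolute value (L1 magnitude).
prune.l1_unstructured(fc_magnitude, name="weight", amount=0.3)

print((fc_random.weight == 0).float().mean())     # ~0.3
print((fc_magnitude.weight == 0).float().mean())  # ~0.3, but the smallest-magnitude entries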

How PyTorch pruning works

The pruning functionality is implemented in the torch.nn.utils.prune module, whose code lives in torch/nn/utils/prune.py. The main pruning classes are shown in the figure below.

[Figure: pytorch_pruning_api_file.png (the main pruning classes in torch/nn/utils/prune.py)]

Pruning is implemented with a mask over the parameter tensor. The mask is a Boolean tensor with the same shape as the parameter: where the mask is True, the weight at that position is kept; where it is False, the weight at that position is removed (zeroed out).

PyTorch copies the original parameter <param> into a new parameter named <param>_orig and creates a buffer <param>_mask to store the pruning mask. It also registers a module-level forward_pre_hook (a callback invoked before each forward pass) that applies the pruning mask to the original weights.
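A minimal sketch of what this reparametrization looks like on a small layer (the exact values depend on random initialization):

import torch
import torch.nn.utils.prune as prune

fc = torch.nn.Linear(4, 2)
prune.random_unstructured(fc, name="weight", amount=0.5)

# "weight" is no longer a Parameter: it is recomputed as weight_orig * weight_mask.
print([n for n, _ in fc.named_parameters()])  # e.g. ['bias', 'weight_orig']
print([n for n, _ in fc.named_buffers()])     # ['weight_mask']
print(fc._forward_pre_hooks)                  # holds the pruning method that reapplies the mask
print(torch.equal(fc.weight, fc.weight_orig * fc.weight_mask))  # True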

The PyTorch pruning APIs and tutorials are rather confusing, so below I summarize the APIs, pruning methods, and how they are classified.

The workflow of model pruning in PyTorch is as follows:

  1. Select the pruning method (or subclass BasePruningMethod to implement your own; a sketch of a custom method follows this list).

  2. Specify the pruning module and parameter name.

  3. Set the parameters of the pruning method, such as the pruning ratio.
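For step 1, a custom method is created by subclassing BasePruningMethod and implementing compute_mask. A minimal sketch, where the 0.1 threshold is an arbitrary value chosen purely for illustration:

import torch
import torch.nn.utils.prune as prune

class ThresholdPruning(prune.BasePruningMethod):
    """Prune every weight whose absolute value is below a fixed threshold."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep only entries whose magnitude exceeds the threshold.
        return default_mask * (t.abs() > self.threshold)

conv = torch.nn.Conv2d(1, 1, 3)
# apply() sets up weight_orig, weight_mask and the forward_pre_hook described earlier.
ThresholdPruning.apply(conv, name="weight", threshold=0.1)
print(conv.weight)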

2.2, local pruning

The PyTorch framework provides two types of local pruning: unstructured and structured. Note that structured pruning is only supported locally, not globally.

2.2.1, local unstructured pruning

The prototype of the function for local unstructured pruning is as follows:

def random_unstructured(module, name, amount)  

1. Function purpose:

Performs unstructured pruning on a weight parameter tensor. The method randomly selects some weights or connections in the tensor to prune, with the pruning rate specified by the user.

2. Function parameter definition:

  • module (nn.Module): The network layer/module that needs to be pruned, such as nn.Conv2d() and nn.Linear().

  • name (str): The name of the parameter to prune, such as "weight" or "bias".

  • amount (int or float): the quantity to prune. If it is a float between 0 and 1, it is interpreted as the fraction of parameters to prune; if it is an integer, it is the absolute number of parameters to remove. For example, amount=0.2 means 20% of the elements will be randomly selected for pruning.

3. The following is an example of using the random_unstructured function.

import torch  
import torch.nn.utils.prune as prune  
conv = torch.nn.Conv2d(1, 1, 4)  
prune.random_unstructured(conv, name="weight", amount=0.5)  
conv.weight  
"""  
tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],  
          [ 0.1411,  0.0000, -0.0000, -0.1031],  
          [-0.0527,  0.0000,  0.0640,  0.1666],  
          [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)  
"""  

Reading the output, half of the weight values in the conv layer are now 0.

2.2.2, local structured pruning

Local structured pruning has two functions, whose prototypes are as follows:

def random_structured(module, name, amount, dim)  
def ln_structured(module, name, amount, n, dim, importance_scores=None)  

1. Function purpose:

Unlike unstructured pruning, which removes individual connection weights, structured pruning removes entire channels of weights.

2. Parameter definition

These are very similar to the local unstructured functions; the only difference is that you must specify the dim parameter (and ln_structured additionally takes an n parameter).

n specifies the norm used to rank the pruned channels (e.g. n=2 for the L2 norm), and dim specifies the dimension along which to prune; a small nn.Linear example follows the lists below.

For torch.nn.Linear:

  • dim = 0: removes a neuron (an entire row of the weight matrix).

  • dim = 1: removes all connections to one input (an entire column of the weight matrix).

For torch.nn.Conv2d:

  • dim = 0 (channels): output-channel pruning, i.e. removing entire filters.

  • dim = 1 (neurons): 2D kernel pruning, i.e. removing the kernels connected to a given input channel.
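For example, for an nn.Linear layer, pruning one output neuron with dim=0 looks like this (a minimal sketch; the 4-in/3-out layer size is arbitrary):

import torch
import torch.nn.utils.prune as prune

fc = torch.nn.Linear(4, 3)
# Remove 1 output neuron (one entire row of the 3x4 weight matrix), chosen by smallest L2 norm.
prune.ln_structured(fc, name="weight", amount=1, n=2, dim=0)
print(fc.weight)  # one of the 3 rows is now all zeros
# With dim=1, an entire column (all connections to one input) would be zeroed instead.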

2.2.3, Local structured pruning sample code

Before writing the sample code, we need to understand the relationship between the Conv2d constructor parameters, the shape of the convolution kernel weight tensor, and its axes.

First, the Conv2d function prototype is as follows:

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)  

In PyTorch, the weight of a regular convolution has shape (C_out, C_in, kernel_height, kernel_width), so the weight of the convolutional layer in the code below has shape [3, 2, 3, 3], and dim = 0 corresponds to the first axis, of size 3 (C_out). Whichever axis dim selects is the axis along which whole slices of the weight tensor are zeroed out by pruning.
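A quick check of this shape, and of which axis each dim value selects (using the same layer configuration as the example below):

import torch

conv = torch.nn.Conv2d(2, 3, 3)
print(conv.weight.shape)  # torch.Size([3, 2, 3, 3]) == (C_out, C_in, kH, kW)
# dim=0 prunes along C_out (whole filters); dim=1 prunes along C_in (the kernels tied to one input channel).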

With these key concepts in place, we can try it in practice; the dim=0 case is shown below.

conv = torch.nn.Conv2d(2, 3, 3)  
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])  
print(norm1)  
"""  
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)  
"""  
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)  
print(conv.weight)  
"""  
tensor([[[[-0.0005,  0.1039,  0.0306],  
          [ 0.1233,  0.1517,  0.0628],  
          [ 0.1075, -0.0606,  0.1140]],  
  
         [[ 0.2263, -0.0199,  0.1275],  
          [-0.0455, -0.0639, -0.2153],  
          [ 0.1587, -0.1928,  0.1338]]],  
  
  
        [[[-0.2023,  0.0012,  0.1617],  
          [-0.1089,  0.2102, -0.2222],  
          [ 0.0645, -0.2333, -0.1211]],  
  
         [[ 0.2138, -0.0325,  0.0246],  
          [-0.0507,  0.1812, -0.2268],  
          [-0.1902,  0.0798,  0.0531]]],  
  
  
        [[[ 0.0000, -0.0000, -0.0000],  
          [ 0.0000, -0.0000, -0.0000],  
          [ 0.0000, -0.0000,  0.0000]],  
  
         [[ 0.0000,  0.0000,  0.0000],  
          [-0.0000,  0.0000,  0.0000],  
          [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)  
"""  

It can be clearly seen from the output that the parameters of the last filter (output channel) of the convolutional layer have been removed (they are now a zero tensor).
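To confirm which filter was removed, we can compare against the per-filter L2 norms of the preserved original weights (a quick check reusing the conv object pruned above; note that ln_structured ranked the filters by the L2 norm since n=2, whereas norm1 above was an L1 norm):

# weight_orig still holds the unpruned values.
filter_norms = torch.norm(conv.weight_orig, p=2, dim=[1, 2, 3])
print(filter_norms)                          # the last filter has the smallest L2 norm
print(torch.argmin(filter_norms))            # index of the pruned filter (here 2, the last one)
print(conv.weight_mask.sum(dim=[1, 2, 3]))   # the pruned filter's mask is all zeros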

The dim = 1 case:

conv = torch.nn.Conv2d(2, 3, 3)  
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])  
print(norm1)  
"""  
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)  
"""  
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)  
print(conv.weight)  
"""  
tensor([[[[ 0.0000, -0.0000, -0.0000],  
          [-0.0000,  0.0000,  0.0000],  
          [-0.0000,  0.0000, -0.0000]],  
  
         [[-0.2140,  0.1038,  0.1660],  
          [ 0.1265, -0.1650, -0.2183],  
          [-0.0680,  0.2280,  0.2128]]],  
  
  
        [[[-0.0000,  0.0000,  0.0000],  
          [ 0.0000,  0.0000, -0.0000],  
          [-0.0000, -0.0000, -0.0000]],  
  
         [[-0.2087,  0.1275,  0.0228],  
          [-0.1888, -0.1345,  0.1826],  
          [-0.2312, -0.1456, -0.1085]]],  
  
  
        [[[-0.0000,  0.0000,  0.0000],  
          [ 0.0000, -0.0000,  0.0000],  
          [ 0.0000, -0.0000,  0.0000]],  
  
         [[-0.0891,  0.0946, -0.1724],  
          [-0.2068,  0.0823,  0.0272],  
          [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)  
"""  

Clearly, along dim=1 the norm of the first slice is the smaller one, so within each filter of shape [2, 3, 3] the first [3, 3] kernel (the one connected to the first input channel) is removed, i.e. it becomes a zero matrix. (The printed values above are L1 norms, p=1, while ln_structured with n=2 ranks by the L2 norm; here the ranking is the same.)

2.3, global unstructured pruning

The local pruning discussed above targets one specific network layer at a time, whereas global pruning treats the model as a whole and removes a specified fraction (or number) of parameters across it. As a result, after global pruning the sparsity ratio of each layer will generally differ. The prototype of the global unstructured pruning function is as follows:

# v1.4.0
def global_unstructured(parameters, pruning_method, **kwargs)
# v2.0.0-rc2
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):

1. Function purpose:

Pools all of the listed parameters (weights and/or biases) together and prunes a specified fraction of them globally, according to the chosen pruning method, regardless of which layer they belong to.

2. Parameter definition :

  • parameters (Iterable of (module, name) tuples): the list of model parameters to prune; each element of the list is a (module, name) pair.

  • pruning_method (function): at present only pruning_method=prune.L1Unstructured appears to be officially supported, although a user-implemented unstructured pruning method can also be used.

  • importance_scores: the importance score of each parameter; if None, the parameter values themselves are used.

  • **kwargs: extra arguments passed to the specific pruning method. For example, amount specifies the fraction or number of parameters to prune.

3. Sample code for the global_unstructured function is as follows.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  
class LeNet(nn.Module):  
    def __init__(self):  
        super(LeNet, self).__init__()  
        # 1 input image channel, 6 output channels, 3x3 square conv kernel  
        self.conv1 = nn.Conv2d(1, 6, 3)  
        self.conv2 = nn.Conv2d(6, 16, 3)  
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension  
        self.fc2 = nn.Linear(120, 84)  
        self.fc3 = nn.Linear(84, 10)  
  
    def forward(self, x):  
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  
        x = x.view(-1, int(x.nelement() / x.shape[0]))  
        x = F.relu(self.fc1(x))  
        x = F.relu(self.fc2(x))  
        x = self.fc3(x)  
        return x  
  
model = LeNet().to(device=device)
  
parameters_to_prune = (  
    (model.conv1, 'weight'),  
    (model.conv2, 'weight'),  
    (model.fc1, 'weight'),  
    (model.fc2, 'weight'),  
    (model.fc3, 'weight'),  
)  
  
prune.global_unstructured(  
    parameters_to_prune,  
    pruning_method=prune.L1Unstructured,  
    amount=0.2,  
)  
# Compute the sparsity of the conv1 layer and of the whole model.
# This relies on Tensor.nelement() (i.e. Tensor.numel()), which returns the total number of elements in a tensor.
print(  
    "Sparsity in conv1.weight: {:.2f}%".format(  
        100. * float(torch.sum(model.conv1.weight == 0))  
        / float(model.conv1.weight.nelement())  
    )  
)  
print(  
    "Global sparsity: {:.2f}%".format(  
        100. * float(  
            torch.sum(model.conv1.weight == 0)  
            + torch.sum(model.conv2.weight == 0)  
            + torch.sum(model.fc1.weight == 0)  
            + torch.sum(model.fc2.weight == 0)  
            + torch.sum(model.fc3.weight == 0)  
        )  
        / float(  
            model.conv1.weight.nelement()  
            + model.conv2.weight.nelement()  
            + model.fc1.weight.nelement()  
            + model.fc2.weight.nelement()  
            + model.fc3.weight.nelement()  
        )  
    )  
)  
# Program output
"""  
Sparsity in conv1.weight: 3.70%  
Global sparsity: 20.00%  
"""  

The output shows that although the overall (global) sparsity of the model is 20%, the sparsity of each individual layer is not necessarily 20%.
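To see the per-layer breakdown, we can loop over the same (module, name) pairs (a small extension reusing parameters_to_prune and model from the snippet above):

for module, name in parameters_to_prune:
    weight = getattr(module, name)
    sparsity = 100. * float(torch.sum(weight == 0)) / weight.nelement()
    print(f"{module.__class__.__name__} '{name}': {sparsity:.2f}% pruned")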

Summary

In addition, the PyTorch framework also provides some helper functions:

  1. torch.nn.utils.prune.is_pruned(module): Determine whether the module has been pruned.

  2. torch.nn.utils.prune.remove(module, name): makes the pruning of the specified parameter in the specified module permanent by removing the re-parametrization: the pruned (masked) tensor becomes the ordinary <param> parameter again, and <param>_orig and <param>_mask are deleted. Note that this does not undo the pruning or restore the original values; see the sketch after this list.
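A minimal sketch of both helpers (assuming a freshly pruned nn.Conv2d; layer sizes are arbitrary):

import torch
import torch.nn.utils.prune as prune

conv = torch.nn.Conv2d(1, 1, 3)
prune.l1_unstructured(conv, name="weight", amount=0.5)
print(prune.is_pruned(conv))  # True: the module carries a pruning forward_pre_hook

# Make the pruning permanent: the masked tensor becomes the plain "weight" parameter again,
# and weight_orig / weight_mask are deleted.
prune.remove(conv, name="weight")
print(prune.is_pruned(conv))                        # False
print("weight_mask" in dict(conv.named_buffers()))  # False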

Although PyTorch provides a built-in pruning API that supports a number of unstructured and structured pruning methods, the API is confusing and its documentation is not very clear, so later I plan to use Microsoft's open-source nni toolkit to implement model pruning.

For more hands-on practice with pruning methods, you can refer to this GitHub repository: Model-Compression.

