Deep learning model pruning, quantization and TensorRT inference
Model pruning algorithm
- Reference: Rethinking the Value of Network Pruning (ICLR 2019)
- GitHub: https://github.com/Eric-mingjie/rethinking-network-pruning
This document introduces the pruning algorithms below, all of which have open-source code on GitHub and were evaluated on the ImageNet and CIFAR datasets. The paper also compares two ways of obtaining the small network: fine-tuning the pruned model versus retraining a network with the pruned structure from scratch. According to the accuracies reported by the authors, for classification networks, training the pruned structure from scratch generally reaches higher accuracy than fine-tuning after pruning.
Pre-defined network structure pruning
Pruning method | References | GitHub | Pruning principle |
---|---|---|---|
L1-norm based Filter Pruning | Pruning Filters for Efficient ConvNets | l1-norm-pruning | Within each convolutional layer, rank the filters by the L1 norm of their weights and remove the given percentage of filters with the smallest norms. |
ThiNet | ThiNet: A Filter Level Pruning Method for Deep Neural Network Compression | https://github.com/Roll920/ThiNet | Channels are selected not by the current layer's own weights but by their effect downstream: the channels whose removal has the least impact on the activations of the next convolutional layer are pruned. |
Regression based Feature Reconstruction | Channel Pruning for Accelerating Very Deep Neural Networks | https://github.com/yihui-he/channel-pruning | LASSO regression selects representative channels of a convolutional layer and prunes the redundant ones; the layer's outputs are then reconstructed from the remaining channels by least squares. |
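As a sketch of the first row of the table: assuming a standard PyTorch `nn.Conv2d`, filters can be ranked by the L1 norm of their weights and the smallest ones dropped. The function name and interface here are illustrative, not taken from the paper's code:

```python
import torch
import torch.nn as nn

def l1_filter_keep_indices(conv, prune_ratio):
    """Rank the filters of a Conv2d layer by the L1 norm of their
    weights and return the indices of the filters to keep."""
    # weight shape: (out_channels, in_channels, kH, kW)
    l1 = conv.weight.data.abs().sum(dim=(1, 2, 3))
    n_keep = max(1, int(conv.out_channels * (1.0 - prune_ratio)))
    keep = torch.argsort(l1, descending=True)[:n_keep]
    return torch.sort(keep).values  # preserve original channel order

conv = nn.Conv2d(3, 8, kernel_size=3)
keep = l1_filter_keep_indices(conv, prune_ratio=0.5)  # keeps 4 of 8 filters
```

In a full pruner, these indices would be used to slice the layer's weight tensor and the input channels of the following layer.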
Automatic network structure pruning
Pruning method | References | GitHub | Pruning principle |
---|---|---|---|
Network Slimming | (1) Learning Efficient Convolutional Networks through Network Slimming (2) SlimYOLOv3: Narrower, Faster and Better for Real-Time UAV Applications | (1) network-slimming (2) https://github.com/PengyiZhang/SlimYOLOv3 | During training, an L1 penalty on the gamma (scale) parameters of the BN layers enforces channel sparsity; when pruning, keep the channels whose gamma values remain large after sparse training. |
Sparse Structure Selection | Data-Driven Sparse Structure Selection for Deep Neural Networks | https://github.com/TuSimple/sparse-structure-selection | Besides channel-level sparse training, coarser structures such as whole residual blocks can also be sparsely trained and removed. |
- Sparse training code (called after `loss.backward()` and before `optimizer.step()`):

```python
import torch
import torch.nn as nn

def updateBN(model, s):
    # Add the subgradient of the L1 sparsity penalty s * |gamma| to the
    # gradient of every BatchNorm2d scale factor (gamma); s is the
    # sparsity regularization strength (args.s in the original code).
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))
```
- Note: a pruning strategy for the channels of residual (shortcut) modules must be chosen. Options include: not pruning them at all (fewer parameters removed); pruning every shortcut-connected layer according to one chosen layer of the residual module; or pruning according to the OR of all the shortcut-connected layers' channel masks. Whichever strategy is used, the two convolutional layers joined by a shortcut must keep the same number of channels.
- When using the network slimming algorithm, every convolutional layer must keep at least one channel.
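A minimal sketch of the channel-selection step of network slimming, assuming the BN gammas have already been sparsely trained. The function name and global-percentile threshold scheme are illustrative; the official `network-slimming` code differs in detail. Each layer keeps at least one channel, as required above:

```python
import torch
import torch.nn as nn

def slimming_masks(model, prune_ratio):
    """Global-threshold channel selection: gather all BN gammas, find
    the value below which `prune_ratio` of channels fall, then build a
    keep-mask per BN layer, keeping at least one channel per layer."""
    gammas = torch.cat([m.weight.data.abs().flatten()
                        for m in model.modules()
                        if isinstance(m, nn.BatchNorm2d)])
    thresh = torch.quantile(gammas, prune_ratio)
    masks = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            mask = m.weight.data.abs() > thresh
            if mask.sum() == 0:  # keep at least one channel
                mask[m.weight.data.abs().argmax()] = True
            masks[name] = mask
    return masks

# demo: a toy model with one BN layer and hand-set gammas
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
model[1].weight.data = torch.tensor([0.01, 0.5, 0.6, 0.02])
masks = slimming_masks(model, prune_ratio=0.5)  # keeps the 2 largest gammas
```

For two BN layers `a` and `b` joined by a shortcut, `masks[a] | masks[b]` implements the mask-OR strategy described in the note above, ensuring both layers keep the same channels.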
TensorRT int8 quantization algorithm
Reference: 8-bit Inference with TensorRT
Quantization overview
- Goal: convert an fp32 convolutional neural network to int8 without a noticeable loss of accuracy;
- Reason: int8 inference offers higher throughput and lower memory requirements;
- Challenge: int8 has significantly lower precision and dynamic range than fp32;
Type | Dynamic range | Smallest positive value |
---|---|---|
fp32 | -3.4 x 10^38 ~ +3.4 x 10^38 | 1.4 x 10^-45 |
fp16 | -65504 ~ +65504 | 5.96 x 10^-8 |
int8 | -128 ~ +127 | 1 |
- Solution: minimize the loss of information both when converting the trained fp32 weights to int8 and when computing activations in int8;
- Result: the int8 method is implemented in TensorRT and requires no additional fine-tuning or retraining.
- Question:
Why quantize after training instead of training directly in int8?
Model training relies on back-propagation and gradient descent, whose quantities, such as the learning rate and the small gradient updates, are floating point; they cannot be represented in int8, so the network cannot be trained directly in int8.
Linear quantization
Formula:
Tensor values = fp32 scale factor x int8 array
According to this formula, int8 quantization comes down to a single fp32 scale factor per tensor. How is that scale factor found?
There are two int8 mapping schemes: unsaturated mapping, which maps the tensor's +/-|max| directly to +/-127, and saturated mapping, which clips at a threshold. It has been verified that unsaturated mapping causes a serious loss of accuracy. The reason is that the positive and negative activations produced by a convolutional layer are usually very asymmetric; a symmetric unsaturated mapping (whose original intention is to preserve as much of the original range as possible) then wastes part of the range on one side, so after scaling, the usable int8 dynamic range is even smaller. In the extreme case, there are no positive values at all after quantization and all the negative values pile up near a few codes, causing severe accuracy loss. Saturated mapping instead first finds a threshold T: values inside [-T, +T] are mapped linearly onto [-127, +127], and values beyond the threshold are clamped to -127 or +127.
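A minimal numeric sketch of saturated symmetric quantization (an illustration, not TensorRT's implementation): given a threshold T, the scale factor is T / 127, values inside [-T, T] map linearly, and values outside clamp to +/-127:

```python
import numpy as np

def quantize_saturated(x, T):
    """Saturated symmetric quantization: clamp to [-T, T], then map
    linearly onto int8 with scale = T / 127."""
    scale = T / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    # Recover approximate fp32 values from the int8 codes.
    return q.astype(np.float32) * scale

x = np.array([-300.0, -10.0, 0.0, 5.0, 80.0], dtype=np.float32)
q, scale = quantize_saturated(x, T=100.0)
# -300 lies outside [-T, T] and saturates to -127; the in-range values
# are rounded to the nearest quantization step of width `scale`.
```

Note that the value -300 is lost to clipping; choosing T is exactly the trade-off between clipping outliers and keeping fine resolution for the bulk of the distribution.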
How to choose the quantization threshold?
1. Calibration steps:
- Perform fp32 inference on the calibration data set.
- For each convolutional layer:
- Collect a histogram of the activations;
- Generate candidate quantized distributions with different saturation thresholds;
- Choose the threshold that minimizes the KL divergence between the reference (fp32) distribution and the quantized distribution.
- In general, the entire calibration process takes a few minutes.
2. Selection of the calibration dataset:
- Representative
- Diverse
- Ideally a subset of the validation dataset
- 1000+ samples
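The per-layer calibration loop above can be sketched as follows. This is a simplified stand-in for TensorRT's entropy calibrator, assuming a precomputed activation histogram; the real implementation differs in its binning and smoothing details:

```python
import numpy as np

def kl_divergence(p, q):
    # KL(P || Q) over matching bins, ignoring empty reference bins.
    p = p / p.sum()
    q = q / q.sum()
    mask = p > 0
    q = np.where(q > 0, q, 1e-12)
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def choose_threshold(hist, bin_edges, num_quant_bins=128):
    """For each candidate threshold, clip the histogram, simulate int8
    quantization by merging bins into `num_quant_bins` groups, expand
    back, and keep the threshold with minimal KL divergence."""
    best_kl, best_t = float("inf"), bin_edges[-1]
    n = len(hist)
    for i in range(num_quant_bins, n + 1):
        ref = hist[:i].astype(np.float64).copy()
        ref[i - 1] += hist[i:].sum()  # fold the clipped tail into the edge bin
        # quantize: merge bins into groups, spread each group's mass
        # uniformly back over its originally non-empty bins
        chunks = np.array_split(ref, num_quant_bins)
        q = np.concatenate(
            [np.full(len(c), c.sum() / max((c > 0).sum(), 1)) * (c > 0)
             for c in chunks])
        if q.sum() == 0:
            continue
        kl = kl_divergence(ref, q)
        if kl < best_kl:
            best_kl, best_t = kl, bin_edges[i]  # right edge of bin i-1
    return best_t

# demo: calibrate on the magnitudes of samples from N(0, 1)
rng = np.random.default_rng(0)
activations = np.abs(rng.normal(size=20000))
hist, edges = np.histogram(activations, bins=512)
T = choose_threshold(hist, edges)
```

The returned T would then be used as the saturation threshold for that layer, with scale factor T / 127.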
Converting deep learning models to TensorRT
[1] Deep learning model PyTorch training and transfer to ONNX and TensorRT deployment
[2] darknet YOLOv4 model to ONNX to TensorRT deployment
[3] yolov5 PyTorch model to TensorRT
[4] CenterFace model to TensorRT
[5] RetinaFace MXNet model to ONNX to TensorRT