网络模型剪枝--学习笔记

网络模型剪枝

一、 什么是模型剪枝

神经网络中的一些权重和神经元是可以被剪枝的,这是因为这些权重可能为零或者神经元的输出大多数时候为零,表明这些权重或神经元是冗余的。
网络剪枝能够有压缩模型、简化模型的优点,自然也会有一定的缺点。最核心的便是精度会有所下降,如果剪枝做的不好,那么精度会下降的非常厉害,如何解决这个问题便是后面学者所研究的方向。当然如果在pruning后精度提高的,这说明原模型过似合(overfit)了,pruning起到了regularization的作用。
**

二、 剪枝分类

**
从network pruning的粒度来说可以分为结构化剪枝(Structured pruning)和非结构化剪枝Unstructured pruning)两类
在这里插入图片描述

(1) 结构化剪枝
权重剪枝:
剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0。一般的,会对权重
矩阵中所有的数值按照大小排序,把排在后面的一定比例的值设为0即可。
在这里插入图片描述

(2) 非结构化剪枝
神经元剪枝:
把权重矩阵中某个神经元节点去掉,则和神经元相连接的突触也要全部去除。相当于同时去除权重矩阵中的某一行和列。如何判断神经元节点的重要程度呢?可以通过计算神经元对应的行和列的权重值的平方和的根的大小进行排序,把排序在后面一定比例的神经元节点去掉
在这里插入图片描述

三、 卷积的结构化剪枝

(1) Filter-wise
一个卷积核被剪枝,那么其前一个Feature Map和下一个Feature Map 都会发生相应的变化
在这里插入图片描述

以上图为例,在第i层卷积中,其中第2、5个卷积核被剪掉(卷积核数量变少,每个卷积核的shape不变);当i-1层的featruemap经过第i层卷积矩阵卷积后得到的第i层feature map,其中的第2、5个channel也相应的被去除。为了匹配第i层featuremap通道维度产生的变化,第i+1层的卷积中的每个卷积核的第2、5个channel的权重被去除(卷积核数量不变,但每个卷积的shape发生改变)。为方便观察,把上图中每层卷积核排列成卷积核矩阵的形状,如下图kernel matrix。通过这种形式,我们继续探讨两种剪枝形式。

(2) 单层中卷积核剪枝
在这里插入图片描述

如上图所示,kernel matrix 中的ni表示第i层feature map的通道深度;ni+1表示第i+1层feature map的通道深度。kernelmatrix中每个卷积核的尺寸为k x k。从第i个卷积层剪掉n个卷积核的算法过程如下:(1)计算每个卷积核的权重绝对值之和。(2)根据的值大小排序。(3)将最小的n个卷积核及对应的feature map剪掉。下一个卷积层中相关的卷积核也要移除。(4)生成了第i层和第i+1层新的权重矩阵,剩余的权重参数被复制到新模型中
(3) Shape-wise
filter-wise剪枝是对完整的卷积核(kxkxc)进行剪枝;channel-wise是对所有卷积核中相同layer进行剪枝。shape-wise的剪枝颗粒度相对而言则更小一些。剪枝对象是所有卷积核中相同位置的部分权重的剪枝。如下图所示可以看到三种剪枝的区别。
在这里插入图片描述

由于每个卷积核中不重要的权重值的位置并不相同,这种剪枝方式可能会导致模型网络丢掉有效的信息。
(4) 剪枝的算法依据
以《Learning Efficient Convolutional Networks Through Network Slimming》为例。conv-layer的每个channel的重要程度可以和batchnorm层关联起来,如果某个channel后的batchnorm层中对应的scaling factor足够小,就说明该channel的重要程度低,可以被忽略。
在这里插入图片描述

如上图中橙色的两个通道被剪枝
batchnorm的公式如下所示:
在这里插入图片描述

其中: 在这里插入图片描述
表示channel scaling factors。为了增加的稀疏程度,方便对channel进行剪枝,训练时需要对每个batchnorm层的scaling factor增加L1的约束。channel-wise和filter-wise既有区别,也有联系。两者使用的剪枝评判方法不同,但最终都会体现在对卷积核或卷积核中某些layer的剪枝。

四、 Torch官方的剪pruning算法讲解

深度学习技术依赖于过参数化模型,这不利于部署。相反,生物神经网络是使用高效的稀疏连接的。通过减少模型中的参数数量来压缩模型的技术非常重要
为了减少内存、电池和硬件消耗,而减少牺牲准确性,在设备上部署轻量级模型。
学习如何使用 torch.nn.utils.prune.
(1) 剪枝

在这里插入图片描述
在这里插入图片描述

(2) 迭代剪枝
在这里插入图片描述

(3) 去掉剪枝的重参数
在这里插入图片描述

(4) 全局剪枝
在这里插入图片描述

*

五、 Torch-Pruning工具箱(结构化剪枝)

是进行结构剪枝的pytorch工具箱. 和pytorch官方提供的基于mask的非结构化剪枝不同, 工具箱移除整个剪枝通道. 自动发现层与层剪枝的依赖关系,可以处理DenseNet, ResNet and DeepLab.
在这里插入图片描述

特性:
卷积网络通道剪枝CNNs (e.g. ResNet, DenseNet, Deeplab) 和 Transformers (即 Bert)
网络图跟踪以及依赖关系(Dependency).
支持网络层: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding.
支持操作: split, concatenation, skip connection, flatten, 等等. 剪枝策略: Random, L1, L2, 等等.
剪枝策略: Random, L1, L2, 等等.
Torch-Pruning 使用 fake inputs输入网络和 torch.jit一样收集网络信息. dependency graph 用来表示计算图和层之间的关系. 由于裁剪一层会影响若干层 , dependecy会自动传播剪枝到其他层并且保存在PruningPlan. 如果模型中有 torch.split或者torch.cat,所有剪枝的indices都会做一些变换的. 
Conv-Conv:oc中减少 这1个通道,下一个卷积每个卷积核的ic通道减少 这一个
Skip Connection:需要考虑ic和上一层的oc互相关联,所以这里shortcut和add都需要传递这种关联

	结构化剪枝中,必须保持各层剪枝的兼容性:

在这里插入图片描述
在这里插入图片描述

	使用例子:

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43391596/article/details/128915474