Towards Optimal Structured CNN Pruning via Generative Adversarial Learning论文笔记

论文地址: https://arxiv.org/abs/1903.09291
github地址:https://github.com/ShaohuiLin/GAL

本文提出基于生成对抗的剪枝策略,这与我目前的方法有些相似。

Motivation

目前的结构剪枝方法存在一些不足:

  • 相对耗时的多阶段优化过程,须反复迭代地剪枝和微调
  • 一般剪枝均为硬剪枝,采用hard pruning mask,无法灵活地优化学习过程
  • 训练和正则化过程依赖于样本target

针对存在的这些不足,本文提出基于生成对抗的剪枝策略(GAL, Generative Adversarial Learning)。

Method

该剪枝策略的框架图如图所示。
框架图
首先将剪枝后的网络看成生成器(Generator),其输出特征通过判决器标为fake,并设置软掩膜(soft mask)控制输出路径;接着将预训练模型设置为baseline,其输出特征通过判决器标为real;同时引入正则化约束,一方面判决器使生成器学习baseline的输出结果的分布,另一方面正则化约束是的生成器中的软掩膜稀疏化(相当于剪枝),最终达到低精度损失的结构剪枝目的。所以,可以看出,该方法为剪枝分类中的adaptive property,能够自动完成网络最优结构的搜索,并且使用了KDGAN的框架方法,将大网络的特征分布传授给小网络。
在对抗学习过程中,Baseline的参数固定,而剪枝模型参数、软掩膜以及判别器参数在训练中更新。
训练过程主要包含两个交替的阶段:

  1. 第一个阶段固定生成器和掩膜,通过对抗训练更新判别器D,损失函数包含对抗损失与对抗正则项;
  2. 第二阶段固定判决器,更新生成器与掩膜,损失函数包含对抗损失中的生成器与baseline特征输出的MSE损失以及生成器和掩膜的正则项。

最终,根据掩膜的阈值和门控方式,对channel、branch或block进行剪枝,从而实现模型的压缩。

Experiment

数据集:MNIST, CIFAR-10, ImageNet ILSVRC 2012
GPU: NVIDIA GTX 1080Ti ×2 (128GB)
模型:(LeNet, VGGNet, DenseNet) with channel selection, GoogLeNet with branch selection, ResNets with block selection and channel selection
优化器,学习率等参数设置因不同模型而不同,故具体细节见论文。

Results

结果一
结果二
结果三

Thoughts

这篇文章也是利用生成对抗来学习压缩网络,值得学习的是一些参数的设置以及mask的训练过程。另外在这当篇文章中,作者用MSE将baseline和剪枝网络的输出进行对齐监督,但我觉得可以进一步对中间层也进行对其监督

猜你喜欢

转载自blog.csdn.net/qq_43812519/article/details/104887164