Discriminator-Cooperated Feature Map Distillation for GAN Compression

CVPR2023| Discriminator-Cooperated Feature Map Distillation for GAN Compression

  • 论文链接:https://arxiv.org/pdf/2212.14169.pdf
  • 源码链接:https://github.com/poopit/DCD-official

简介

图像生成任务将随机噪声或源域图像转换到另一个用户需要域的图像。近些年GAN蓬勃发展,产生了大量图像到图像翻译、风格迁移、图像生成等研究。图像生成在日常娱乐上由广泛应用。然而运行这些程序的平台通常有较差内存存储和有限计算能力的特点。然而GAN也因可学习参数和乘法累加运算(MAC)的激增而臭名昭著,这对部署基础设施的存储需求和计算能力提出了巨大挑战。

为了解决上述问题使GAN有更好的服务生活能力,例如剪枝神经网络搜索量化的方法已经广泛探索以获得一个更小的生成器。在这些模型压缩研究的前提下,知识蒸馏特别是特征图知识蒸馏,已被认为使提高压缩生成器性能的补充手段。

与在图像分类任务中的实现,基于特征图的蒸馏方法也在GAN压缩中考虑。例如,DMAD考虑一个预训练的判别器从教师生成图像中高级别信息,与教师生成器的中间激活混合,将结果传递到学生生成器相应位置。OMGD使用在线多粒度策略,允许更深层次教师生成器和更广泛生成器同时项学生生成器提供不同粒度的输出图像知识。

具体地,与依赖于特征向量表示的图像分类相反,图像生成的本质是改善真实图像与生成图像间感知相似性。两个重要事实导致很难使用每个像素匹配分析一堆图像:1. 两个相似的图像可以包含很多不同像素值,2. 两个不同图像仍然可以包括相同像素值。因此简单地使用每像素匹配是不合适的。关于GANs的对抗训练,生成器学习合成与数据集最相似的样本,同时鉴别器将数据集中样本和生成的样本区分开。对抗性的结果最终导致生成器创造出一场视觉质量的图像,这表明鉴别器也具有信息能力,可以用来丰富特征图的提取。这表明直接从图像分类扩展特征图蒸馏到图像生成时不合适的。
img

本文方法

相关定义

GAN是通过用包括生成器模型 G \mathcal{G} G和鉴别器模型 D \mathcal{D} D的两个子模型对问题进行建模来训练生成器模型的一种有趣方式。
min ⁡ G max ⁡ D L g a n = E y ∼ p r e a l log ⁡ D ( y ) + E x ∼ p ( x ) [ log ⁡ ( 1 − D ( G ( x ) ) ) ] \min_{\mathcal{G}}\max_{\mathcal{D}}\mathcal{L}_{gan} = \mathbb{E}_{y\sim{p_{real}}} \log \mathcal{D}(y) + \mathbb{E}_{x\sim{p(x)}}[\log (1-\mathcal{D}(\mathcal{G(x)}))] GminDmaxLgan=EypreallogD(y)+Exp(x)[log(1D(G(x)))]
生成模型训练用于生成新示例,判别器模型 D \mathcal{D} D尝试分类示例真实的或虚假的(生成的)。

给定输入变量 x ∼ p ( x ) ∈ R H × W × C x\sim{p(x)}\in \mathbb{R}^{H\times W\times C} xp(x)RH×W×C,第 i i i层生成器输出为 G i ( x ) \mathcal{G}_{i}(x) Gi(x),提取中间输出层索引 I G ( x ) \mathcal{I}_{\mathcal{G}}(x) IG(x)。基于特征图的知识蒸馏损失可以描述为:
L f e a − d i s = ∑ i ∈ I G ℓ ( G i T ( x ) , f ( G i S ( x ) ) ) \mathcal{L}_{fea-dis} = \sum_{i\in I_{\mathcal{G}}} \ell(\mathcal{G}_{i}^{T}(x),f(\mathcal{G}_{i}^{S}(x))) Lfeadis=iIG(GiT(x),f(GiS(x)))
f ( ⋅ ) f(\cdot) f()是仿射变换函数以对齐教师和学生模型之间通道维度。

L p e r \mathcal{L}_{per} Lper在现有研究中广泛使用,以鼓励自然和令人愉悦的修复结果。 L p e r \mathcal{L}_{per} Lper由一个特征重建损失 L f e a \mathcal{L}_{fea} Lfea和风格重建损失 L s t y \mathcal{L}_{sty} Lsty组成:
L p e r = λ f e a ⋅ L f e a + λ s t y ⋅ L s t y \mathcal{L}_{per} = \lambda_{fea}\cdot \mathcal{L}_{fea}+\lambda_{sty}\cdot \mathcal{L}_{sty} Lper=λfeaLfea+λstyLsty
L f e a \mathcal{L}_{fea} Lfea推动教师生成器的输出表示接近学生生成器表示。这由一个预训练VGG网络 Φ ( ⋅ ) \Phi(\cdot) Φ()获得:
L f e a = ∑ j ∈ I Φ 1 H j W j C j ∣ ∣ Φ j ( G T ( x ) ) − Φ j ( G S ( x ) ) ∣ ∣ 1 \mathcal{L}_{fea} = \sum_{j\in I_{\Phi}}\frac{1}{H_{j}W_{j}C_{j}}||\Phi_{j}(\mathcal{G}^{T}(x)) - \Phi_{j}(\mathcal{G}^{S}(x))||_{1} Lfea=jIΦHjWjCj1∣∣Φj(GT(x))Φj(GS(x))1
对于 L s t y \mathcal{L}_{sty} Lsty,它最小化输出和目标图像的Gram矩阵差异用于保护风格特征例如颜色、纹理和常见范式。

判别器合作蒸馏

回顾基于特征图的蒸馏方法,意识到早期方法中生成器容量简单利用是不完整。GAN原理却决于通过动态更新的鉴别器的间接训练路线,以辨别器生成器输出看起来多真实。这意味着生成器不是训练来缩小生成图像和目标图像间差异,而是欺骗判别器。生成器和鉴别器间合作竞争模式甚至带来了表面上真实的生成图像。因此鉴别器也具有信息能力,必须用于丰富特征图提取。

本文重新思考特征图蒸馏损失,并集成教师鉴别器配合蒸馏过程。类似于前文定义的特征重建损失,在生成器输出作为输入同时,通过对鉴别器中间输出完成本文蒸馏过程。
L d c d = ∑ k ∈ I D ∑ i ∈ I G ℓ ( D k T ( f ( G i T ( x ) ) ) , D k T ( f ( G i S ( x ) ) ) ) \mathcal{L}_{dcd} = \sum_{k\in I_{\mathcal{D}}}\sum_{i\in I_{\mathcal{G}}} \ell(\mathcal{D_{k}}^{T}(f(\mathcal{G}_{i}^{T}(x))),\mathcal{D_{k}}^{T}(f(\mathcal{G}_{i}^{S}(x)))) Ldcd=kIDiIG(DkT(f(GiT(x))),DkT(f(GiS(x))))
类似于迫使教师生成器和学生生成器的中间输出之间的每像素匹配的普通特征图的蒸馏,教师鉴别器作用类似于预训练的VGG,充当转换网络,使学生生成器的中间输出在感知上与教师生成器中间结果相似。相比之下,它不需要逐像素匹配。本文的判别器合作蒸馏是对感知损失的补充。

合作对抗训练

GAN为独特的全局平衡点执行交替训练范式:1. 训练判别器 D \mathcal{D} D判断真实和生成样本并保持生成器不变,2. 训练生成器产生可以愚弄判别器的生动数据并保持判别器不变,3. 重复1,2直到大部分时候判别器被愚弄。然而在 D \mathcal{D} D G \mathcal{G} G被赋予不一致能力时不能保证平衡,且经常发生不稳定的收敛。一般的,不稳定问题来源于,1. 梯度消失:当判别器完美时生成器损失为0,2. 模式崩坏:较强的生成器为任何输入产生一小组输出,较弱的判别器陷入局部最小值中。模式崩坏问题在GAN压缩中普遍存在,原因是压缩后的生成器 G S \mathcal{G}^{S} GS无力与原始的鉴别器竞争。目前已经提出很多削弱学生判别器的方法:GCC选择性激活判别器神经元。然而必须仔细设计经验法则的选择过程。此外训练学生鉴别器计算上是多余的,因为它在测试阶段是不需要的。OMGD算法中共同训练教师判别器和教师生成器且不存在学生判别器。学生判别器缺失以一定方式阻碍了进一步性能提升。本文方法的教师判别器也以协作判别器形式出现,以确定学生生成器给的输入是真的还是假的。主要问题源于教师判别器比压缩后的学生生成器强大得多。幸运的是,本文发现判别器合作蒸馏为学生生成器提供了增强的对抗教师判别器能力。
min ⁡ G max ⁡ D L c o l = E y ∼ p r e a l log ⁡ D ( y ) + E x ∼ p ( x ) [ log ⁡ ( 1 − D ( G ( x ) ) ) ] + λ s t u ⋅ E x ∼ p ( x ) [ log ⁡ ( 1 − D T ( G S ( x ) ) ) ] \min_{\mathcal{G}}\max_{\mathcal{D}}\mathcal{L}_{col} = \mathbb{E}_{y\sim{p_{real}}} \log \mathcal{D}(y) + \mathbb{E}_{x\sim{p(x)}}[\log (1-\mathcal{D}(\mathcal{G(x)}))] + \lambda_{stu}\cdot \mathbb{E}_{x\sim{p(x)}}[\log (1-\mathcal{D}^{T}(\mathcal{G}^{S}(x)))] GminDmaxLcol=EypreallogD(y)+Exp(x)[log(1D(G(x)))]+λstuExp(x)[log(1DT(GS(x)))]

训练目标

在大部分传统基于特征图蒸馏算法的损失项包括 L g a n \mathcal{L}_{gan} Lgan L f e a − g a n \mathcal{L}_{fea-gan} Lfeagan L p e r \mathcal{L}_{per} Lper。本文通过判别器合作特征图蒸馏 L d c d \mathcal{L}_{dcd} Ldcd改进 L f e a − g a n \mathcal{L}_{fea-gan} Lfeagan,通过合作对抗训练损失 L c o l \mathcal{L}_{col} Lcol改进 L g a n \mathcal{L}_{gan} Lgan
min ⁡ G T , G S max ⁡ D T ( L g a n + L p e r + λ d c d ⋅ L d c d ) \min_{\mathcal{G}^{T},\mathcal{G}^{S}}\max_{\mathcal{D}^{T}}(\mathcal{L}_{gan} + \mathcal{L}_{per} + \lambda_{dcd}\cdot\mathcal{L}_{dcd}) GT,GSminDTmax(Lgan+Lper+λdcdLdcd)

猜你喜欢

转载自blog.csdn.net/qgh1223/article/details/130375829