Modelo GAN de entrada de base cero de PyTorch cGAN

Introducción de fondo

En artículos anteriores, presentamos los principios de las GAN y cómo evaluar modelos entrenados . Algunos de ustedes habrán visto cómo se generan las imágenes de una sola categoría, por ejemplo, CIFAR10 e ImageNet contienen imágenes de múltiples categorías, si quiero entrenar una red antagónica generativa que pueda generar imágenes de múltiples categorías, ¿qué debo hacer? ¿Paño de lana?
Entonces, los modelos llamados GAN condicionales (cGAN) pueden ser útiles.

Desarrollo de cGAN

1. Cómo introducir información de categoría

El primer cGAN se propuso en el documento "Redes adversarias generativas condicionales". Para lograr el propósito de la generación condicional, concatenamos un vector de etiqueta y en la entrada de ruido z a la red generadora G, diciéndole a la red de generación que genere los datos. especificado por la etiqueta. Para la entrada de datos al discriminador D, también se usa una etiqueta como concat para decirle a la red discriminante que juzgue si la entrada son datos reales de esta categoría.

Entonces, la función objetivo de cGAN se puede expresar de la siguiente manera:

cGAN utiliza MLP como estructura de red, la entrada unidimensional se puede integrar fácilmente con el vector de etiquetas o concatenación de etiquetas, pero para el modelo de CNN convencional de tareas de generación de imágenes, este método de introducción no se puede adoptar directamente, especialmente para la red discriminadora.
La proyección GAN se encuentra a través de la derivación, suponiendo que
[oficial]
es  [oficial] la red discriminadora,  [oficial] la función de activación y  [oficial] el modelo de red. Bajo la función objetivo anterior, la óptima  [oficial] es
[oficial]

. De acuerdo con la fórmula anterior, la imagen de entrada primero pasa por la red para  [oficial] extrae características unidimensionales y luego las divide en dos canales. , todo el camino a través de la red para  [oficial] generar el resultado de discriminación sobre la autenticidad de la imagen, todo el camino para hacer el producto punto con la etiqueta de categoría codificada para obtener la discriminación result sobre la categoría, y luego sume los dos resultados para obtener el resultado de discriminación final. La estructura del modelo es la siguiente:

2.如何稳定训练过程

GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。
相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。

3.如何提升生成质量

在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。

SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。

这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。
SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。

集大成者:BigGAN

在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR
HIGH FIDELITY NATURAL IMAGE SYNTHESIS》 (即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升)

实际训练中,我们在有限的显卡资源下,很难直接采用 2048 的 batchsize,
为了让 latent 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。

其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。

使用 MMGeneration 上手 BigGAN

在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。
在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。
我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的 BigGAN 中 sample 类别随机的图片。

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth 
复制代码

当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。
比如运行下面代码 :

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5 
复制代码

可以得到 狗(151),猫(285),老虎(292)各五张图片。

我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。

python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声 
 
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签 
复制代码


结果分别如下 :\

现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。ImageNet 的下载方法可以参考这篇知乎分享 。


BigGAN 的训练有几个关键点设置,我们以 configs/base/models/biggan/biggan_128x128.py 和
configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。
首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。

# 使用 BigGAN 作者提供的 SN 实现 
model = dict( 
    type='BasiccGAN', 
    generator=dict(xxx, sn_style='ajbrock'), 
    discriminator=dict(xxx, sn_style='ajbrock'), 
    gan_loss=dict(type='GANLoss', gan_type='hinge')) 
复制代码

在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。

train_cfg = dict( 
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True) 
复制代码

这个 config 表示训练 8 步判别器,再训练一步生成器。
在上段代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8832 = 2048 的 batchsize。
为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。 在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。
为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。

custom_hooks = [ 
    xxx, 
    dict( 
        type='ExponentialMovingAverageHook', 
        module_keys=('generator_ema', ), 
        interval=8, 
        start_iter=160000, 
        interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999), 
        priority='VERY_HIGH') 
] 
复制代码

Aquí interval es la frecuencia de actualización del parámetro generator_ema, start_inter es la iteración de inicio de ema, y ​​antes de eso, los parámetros de red se copian del generador. Momentum en interp_cfg es el peso actualizado de los parámetros, y momentum_nontrainable es el peso actualizado de los búferes.
Para otros detalles e implementación específica del entrenamiento del modelo, puede  consultar el código en MMGeneration  (por supuesto, también podemos emitir una explicación detallada ~)
La configuración anterior, ya la hemos escrito para usted en la configuración, solo ejecute el siguiente código ¡Puedes comenzar a entrenar tu propio modelo BigGAN!

bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan 
复制代码

Durante el proceso de entrenamiento, puede ver las imágenes generadas por el modelo en diferentes etapas en work_dirs/biggan/training_samples. Las primeras cuatro líneas son generator_ema para generar imágenes, y las últimas cuatro líneas son las imágenes generadas por el generador con la misma entrada.

De hecho, el muestreo condicional, la interpolación condicional y el entrenamiento de modelos anteriores   también son aplicables a SNGAN y SAGAN que han sido compatibles con MMGeneration. Puede usarlo y hacer comentarios en cualquier momento ~ Gracias

github.com/open-mmlab/…

Supongo que te gusta

Origin juejin.im/post/7083427895216963597
Recomendado
Clasificación