PyTorch zero-based entry GAN model cGAN

Background introduction

In previous articles, we introduced the principles of GANs and how to evaluate trained models . Some of you may have seen how the generated images are of a single category. For example, CIFAR10 and ImageNet contain multiple categories of images. If I want to train a generative adversarial network that can generate multiple categories of images, what should I do? Woolen cloth?
Then models called conditional GANs (cGANs) can come in handy.

Development of cGAN

1. How to introduce category information

The first cGAN was proposed in the paper "Conditional Generative Adversarial Nets". In order to achieve the purpose of conditional generation, we concat a label vector y on the noise z input to the generator network G, telling the generation network to generate the data specified by the label. For the data input to the discriminator D, a label such as concat is also used to tell the discriminant network to judge whether the input is real data of this category.

Then, the objective function of cGAN can be expressed as follows:

cGAN uses MLP as the network structure, one-dimensional input can be easily embedded with label vector or label concat, but for the mainstream CNN model of image generation tasks, this introduction method cannot be directly adopted, especially for the discriminator network.
Projection GAN is found by derivation, assuming that
[official]
it  [official] is the discriminator network,  [official] the activation function, and  [official] the network model. Under the above objective function, the optimal one  [official] is
[official]

. According to the above formula, the input image first goes through the network to  [official] extract one-dimensional features, and then divides it into two channels. , all the way through the network to  [official] output the discrimination result about the authenticity of the image, all the way to do the dot product with the encoded category label to get the discrimination result about the category, and then add the two results to get the final discrimination result. The model structure is as follows:

2.如何稳定训练过程

GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。
相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。

3.如何提升生成质量

在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。

SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。

这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。
SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。

集大成者:BigGAN

在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR
HIGH FIDELITY NATURAL IMAGE SYNTHESIS》 (即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升)

实际训练中,我们在有限的显卡资源下,很难直接采用 2048 的 batchsize,
为了让 latent 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。

其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。

使用 MMGeneration 上手 BigGAN

在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。
在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。
我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的 BigGAN 中 sample 类别随机的图片。

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth 
复制代码

当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。
比如运行下面代码 :

python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5 
复制代码

可以得到 狗(151),猫(285),老虎(292)各五张图片。

我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。

python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声 
 
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签 
复制代码


结果分别如下 :\

现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。ImageNet 的下载方法可以参考这篇知乎分享 。


BigGAN 的训练有几个关键点设置,我们以 configs/base/models/biggan/biggan_128x128.py 和
configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。
首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。

# 使用 BigGAN 作者提供的 SN 实现 
model = dict( 
    type='BasiccGAN', 
    generator=dict(xxx, sn_style='ajbrock'), 
    discriminator=dict(xxx, sn_style='ajbrock'), 
    gan_loss=dict(type='GANLoss', gan_type='hinge')) 
复制代码

在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。

train_cfg = dict( 
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True) 
复制代码

这个 config 表示训练 8 步判别器,再训练一步生成器。
在上段代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8832 = 2048 的 batchsize。
为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。 在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。
为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。

custom_hooks = [ 
    xxx, 
    dict( 
        type='ExponentialMovingAverageHook', 
        module_keys=('generator_ema', ), 
        interval=8, 
        start_iter=160000, 
        interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999), 
        priority='VERY_HIGH') 
] 
复制代码

Here interval is the update frequency of the generator_ema parameter, start_inter is the ema start iteration, and before that, the network parameters are copied from the generator. Momentum in interp_cfg is the updated weight of parameters, and momentum_nontrainable is the updated weight of buffers.
For other details and specific implementation of model training, you can refer  to the code in MMGeneration  (of course, we may also issue a detailed explanation~)
The above settings have been written for you in config, just run the following code, You can start training your own BigGAN model!

bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan 
复制代码

During the training process, you can view the images generated by the model at different stages under work_dirs/biggan/training_samples. The first four lines are generator_ema to generate pictures, and the last four lines are the pictures generated by the generator with the same input.

In fact, the above conditional sampling, conditional interpolation, and model training   are also applicable to SNGAN and SAGAN that have been supported in MMGeneration. You are welcome to use it and make comments at any time~ Thanks

github.com/open-mmlab/…

Guess you like

Origin juejin.im/post/7083427895216963597