mm中的GAN模型架构

mmgeneration中的GAN的架构包括,输入一张真实图像,如果是条件gan的话,输入的是一张真实图像和相应的标签,

1.首先对判别器进行训练,对判别器中梯度在训练时保存,和对判别器优化器进行zero_grad(),gan中生成器和判别器各有一个优化器,其实是两个相对独立的训练过程。

2.在训练判别器时也先使用生成器生成一个假样本,注意此时的生成器是不训练,仅仅是生成一个假样本,在条件gan时,还存在一个batch_accumulation_step的步骤,它在生成器训练时会多训练几次。

3.用判别器分别对真样本和假样本进行判别,计算判别器训练时的损失,正常的ganloss以及相应的辅助损失,放在disc_auxiliary_loss中,感觉gan中能够扩展的方向并不多,从架构上讲是transformers或者self-attention之类的,在理论上大体就是损失函数,l约束上的一些操作,所以此时的gen_auxiliary_loss是不同的一个点。判别器的loss本质上是个二分类,负样本输入,给的标签是0,正样本输入给的标签是1.

4.计算完损失之后,loss_disc.bachward()反向传播,optimizer['disciminator'].step()梯度更新,到这里判别器就训练完成了。

5.进行5-8轮的判别器训练之后开始转向生成器训练,每次训练前其实都对梯度进行了清零,相当于在每一轮时并不进行梯度累计,首先对判别器的梯度不进行保存,注意一开始训练判别器时,判别器的梯度是保存的,然后对生成器梯度进行zero_grad()。

6.训练生成器,生成器的输入还是噪声,得到生成器的图像之后,进行一次判别器判定,判别器最后一层一般是个Linear(n,1)的层,也就是输出是个N,1维的,输入到gen_loss中,这里gen_loss中核心也是辅助损失gen_auxiliary_loss中,此处也可能添加一些l约束之类的。

7.计算完损失之后,loss_gen.backward()反向传播,optimizer['generator'].step()梯度更新,生成器训练完成。

上图是一个典型的生成器和判别器的结构,在生成器中,我们需要一个将噪声向量转换成为二维特征的模块,也就是noise2feat block。接下来需要连续经过几个上采样块将低分辨率的特征转成高分辨率的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/125540031
mm
GAN
今日推荐