【GAN】【论文笔记】A Style-Based Generator Architecture for Generative Adversarial Networks

1.论文工作

论文工作可以参考【StyleGAN-基于样式的生成对抗网络

2.代码

代码参考styleGAN.py - facebookresearch/pytorch_GAN_zoo - GitHub1s

Trainer简单结构如下(有一定省略):

2.1 映射网络

2.1.1 网络结构

由8个MLP组成,其输出w与输入z大小相同(512×1)

class MappingLayer(nn.Module):
#dimInput:512 dimLatent:512 nLayers:8
    def __init__(self, dimIn, dimLatent, nLayers, leakyReluLeak=0.2):
        super(MappingLayer, self).__init__()
        self.FC = nn.ModuleList()

        inDim = dimIn
        for i in range(nLayers):
            self.FC.append(EqualizedLinear(inDim, dimLatent, lrMul=0.01, equalized=True, initBiasToZero=True))
            inDim = dimLatent

        self.activation = torch.nn.LeakyReLU(leakyReluLeak)

    def forward(self, x):
        for layer in self.FC:
            x = self.activation(layer(x))

2.1.2 映射网络的输入

    def buildNoiseData(self, n_samples, inputLabels=None):
        inputLatent = torch.randn(
            n_samples, self.config.noiseVectorDim).to(self.device)
        return inputLatent, None
    #根据噪声维度生成随机输入


    inputLatent, targetRandCat = self.buildNoiseData(n_samples)

2.1.3 截断技巧

使用截断技巧避免了从W的极端区域进行采样

计算方法:

  1. 计算质心
  2. 将给定w与中心的偏差缩放为w' =\overline{ w} +\psi \left ( w-\overline{w} \right )
        if self.training:
            self.mean_w = self.gamma_avg * self.mean_w + (1-self.gamma_avg) * mapping.mean(dim=0, keepdim=True)

        if self.phiTruncation < 1:#0.999
            mapping = self.mean_w + self.phiTruncation * (mapping - self.mean_w)

2.2 AdaIN

(虽然在StyleGAN2就发现AdaIN会引入“水滴”问题)

  • ①首先每个特征图xi(feature map)独立进行归一化\frac{\left ( x_i - \mu (x_i ) \right )}{\sigma(x)) }。特征图中的每个值减去该特征图的均值然后除以方差
  • ②一个可学习的仿射变换A(全连接层)将w转化为style中AdaIN的平移和缩放因子y =(ys,i,yb,i),
  • ③然后对每个特征图分别使用style中学习到的的平移和缩放因子进行尺度和平移变换。
class AdaIN(nn.Module):

    def __init__(self, dimIn, dimOut, epsilon=1e-8):
        super(AdaIN, self).__init__()
        self.epsilon = epsilon
        self.styleModulator = EqualizedLinear(dimIn, 2*dimOut, equalized=True,
                                              initBiasToZero=True)
        self.dimOut = dimOut

    def forward(self, x, y):

        # x: N x C x W x H
        batchSize, nChannel, width, height = x.size()
        tmpX = x.view(batchSize, nChannel, -1)
        mux = tmpX.mean(dim=2).view(batchSize, nChannel, 1, 1)
        varx = torch.clamp((tmpX*tmpX).mean(dim=2).view(batchSize, nChannel, 1, 1) - mux*mux, min=0)
        varx = torch.rsqrt(varx + self.epsilon)
        x = (x - mux) * varx

        # Adapt style
        styleY = self.styleModulator(y)
        yA = styleY[:, : self.dimOut].view(batchSize, self.dimOut, 1, 1)
        yB = styleY[:, self.dimOut:].view(batchSize, self.dimOut, 1, 1)

        return yA * x + yB

2.3 synthesis network

2.3.1  输入常数

输入了一个随机生成的可学习参数作为初始量

self.baseScale0 = nn.Parameter(torch.ones(1, dimMapping, 4, 4), requires_grad=True)

2.3.2 随机变化

 StyleGAN的架构通过在每次卷积后添加噪声来实现生成图像在空间上有一些随机的(细节)变化。

            noiseMod = self.noiseModulators[nLayer]
            feature = Upscale2d(feature)
                        #卷积
            feature = group[0](feature) + noiseMod[0](torch.randn((batchSize, 1,
                                                      feature.size(2),
                                                      feature.size(3)), device=x.device))
            feature = self.activation(feature)
                       #AdaIN
            feature = group[1](feature, mapping)
                        #卷积
            feature = group[2](feature) + noiseMod[1](torch.randn((batchSize, 1,
                                                      feature.size(2),
                                                      feature.size(3)), device=x.device))
            feature = self.activation(feature)
                        #AdaIN
            feature = group[3](feature, mapping)

2.3.3 噪声生成

class NoiseMultiplier(nn.Module):

    def __init__(self):
        super(NoiseMultiplier, self).__init__()
        self.module = nn.Conv2d(1, 1, 1, bias=False)
        self.module.weight.data.fill_(0)

    def forward(self, x):

        return self.module(x)

猜你喜欢

转载自blog.csdn.net/weixin_50862344/article/details/131588020