An intuitive introduction to Generative Adversarial Networks (GANs)

生成对抗网络(GAN)的直观介绍

  • Warm up

让我们说你附近有一个非常酷的派对,你真的想去。 但有个问题。 要进入聚会,你需要一张特价机票 - 这个机票已经售罄。

等等! 这不是Generative Adversarial Networks的文章吗? 是的。 但是现在忍受我,这是值得的。

好吧,由于期望很高,派对的组织者聘请了一个合格的安全机构。 他们的主要目标是不允许任何人崩溃。 为此,他们在场地入口处放了很多警卫,检查每个人的真实性门票。

由于你没有任何武术艺术礼品,唯一的办法就是用一张非常令人信服的假票欺骗他们。

这个计划存在一个很大的问题 - 你从未真正看到这张票的样子。
即使你根据自己的创造力设计了一张罚单,在第一次试用时几乎不可能愚弄守卫。 此外,在你有一个非常体面的派对的传球复制品之前,你不能露面。

为了帮助解决问题,您决定打电话给您的朋友Bob为您完成工作。

鲍勃的使命非常简单。 他会尝试用假通行证进入派对。 如果他被拒绝,他会回复你,提供关于机票应该如何的有用提示。

根据这些反馈,您可以制作新版本的故障单并将其交给Bob,后者再次尝试。 这个过程不断重复,直到您能够设计出完美的副本。

在这里插入图片描述
暂且不谈这个轶事中的“小漏洞”,这几乎就是Generative Adversarial Networks(GAN)的工作原理。

如今,GAN的大多数应用都属于计算机视觉领域。 一些应用包括训练半监督分类器,以及从低分辨率对应物生成高分辨率图像。

本文介绍了GAN,并采用实际操作方法来解决生成图像的问题。 您可以在此处克隆此帖子的笔记本。

  • Generative Adversarial Networks
    在这里插入图片描述
    Generative Adversarial Network framework.

GAN是Goodfellow等人设计的生成模型。 在GAN设置中,由神经网络表示的两个可区分函数被锁定在游戏中。 两个参与者(生成器和鉴别器)在此框架中具有不同的角色。

生成器尝试生成来自某些概率分布的数据。 那就是你试图重现派对的门票。

鉴别者就像一个法官。 它决定输入是来自发生器还是来自真正的训练集。 这将是派对的安全性,将您的假票与真实票证进行比较,以找到您设计中的缺陷。

  • In summary, the game follows with:
    试图最大化使鉴别器错误输入的概率的发生器是真实的。
    并且鉴别器引导发生器产生更逼真的图像。
    在完美均衡中,发电机将捕获一般的训练数据分布。 结果,鉴别器总是不确定其输入是否真实。

在这里插入图片描述
改编自DCGAN论文。 这里实现了Generator网络。 注意完全连接和池化层不存在。

在DCGAN论文中,作者描述了一些深度学习技术的组合作为训练GAN的关键。 这些技术包括:(i)所有卷积网和(ii)批量标准化(BN)。

第一个强调两个方面的跨步卷积(而不是汇集层):增加和减少特征的空间维度。 并且第二个将特征向量归一化以在所有层中具有零均值和单位方差。 这有助于稳定学习并处理重量不足的初始化问题。

不用多说,让我们深入了解实施细节,并在我们开始时更多地讨论GAN。 我们提出了深度卷积生成对抗网络(DCGAN)的实现。 我们的实现使用Tensorflow并遵循DCGAN文件中描述的一些实践。

  • Generator

该网络有4个卷积层,后面都是BN(输出层除外)和整流线性单元(ReLU)激活。

它将随机向量z(从正态分布中绘制)作为输入。在重塑z以获得4D形状后,我们将其提供给启动一系列上采样层的生成器。

每个上采样层表示具有步幅2的转置卷积运算。转置卷积类似于常规卷积。

通常,常规卷积从宽层和浅层变为更窄和更深层。转置卷积是另一种方式。它们从深而窄的层变为更宽更浅。

转置卷积运算的步幅定义了输出层的大小。使用“相同”填充和2的步幅,输出功能将具有输入层的两倍大小。

之所以发生这种情况,是因为每当我们在输入层中移动一个像素时,我们就会在输出层上将卷积内核移动两个像素。换句话说,输入图像中的每个像素用于在输出图像中绘制正方形。

在这里插入图片描述

在2x2输入上使用步幅2对3x3内核进行转置,相当于在5x5输入上使用步幅2对3x3内核进行卷积。对于两者,不使用填充“VALID”。

简而言之,生成器从这个非常深但很窄的输入向量开始。 在每个转置卷积之后,z变得更宽和更浅。 所有转置卷积使用5x5内核的大小,深度从512一直减少到3 - 表示RGB彩色图像。

def transpose_conv2d(x, output_space):
    return tf.layers.conv2d_transpose(x, output_space, 
      kernel_size=5, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(mean=0.0,
                                                      stddev=0.02))

最后一层输出一个32x32x3张量 - 通过双曲正切(tanh)函数在-1和1之间压扁。

最终输出形状由训练图像的大小定义。 在这种情况下,如果训练SVHN,则生成器生成32x32x3图像。 但是,如果训练MNIST,它将生成28x28灰度图像。

最后,请注意,在将输入向量z馈送到生成器之前,我们需要将其缩放到-1到1的间隔。这是遵循使用tanh函数的选择。

def generator(z, output_dim, reuse=False, alpha=0.2, training=True):
    """
    Defines the generator network
    :param z: input random vector z
    :param output_dim: output dimension of the network
    :param reuse: Indicates whether or not the existing model variables should be used or recreated
    :param alpha: scalar for lrelu activation function
    :param training: Boolean for controlling the batch normalization statistics
    :return: model's output
    """
    with tf.variable_scope('generator', reuse=reuse):
        fc1 = dense(z, 4*4*512)

        # Reshape it to start the convolutional stack
        fc1 = tf.reshape(fc1, (-1, 4, 4, 512))
        fc1 = batch_norm(fc1, training=training)
        fc1 = tf.nn.relu(fc1)

        t_conv1 = transpose_conv2d(fc1, 256)
        t_conv1 = batch_norm(t_conv1, training=training)
        t_conv1 = tf.nn.relu(t_conv1)

        t_conv2 = transpose_conv2d(t_conv1, 128)
        t_conv2 = batch_norm(t_conv2, training=training)
        t_conv2 = tf.nn.relu(t_conv2)

        logits = transpose_conv2d(t_conv2, output_dim)

        out = tf.tanh(logits)
        return out
  • Discriminator
    鉴别器也是具有BN(除了其输入层)和泄漏的ReLU激活的4层CNN。 使用这种基本的GAN架构,许多激活功能都可以正常工作。 但是,泄漏的ReLU非常受欢迎,因为它们有助于梯度更容易地通过架构流动。

常规ReLU功能通过将负值截断为0来工作。这具有阻止梯度流过网络的效果。 泄漏的ReLU允许小的负值通过,而不是函数为零。 也就是说,该函数计算特征与小因子之间的最大值。

def lrelu(x, alpha=0.2):
     # non-linear activation function
    return tf.maximum(alpha * x, x)

Leaky ReLUs代表了解决垂死的ReLU问题的尝试。 当神经元陷入ReLU单元总是为所有输入输出0的状态时,就会发生这种情况。 对于这些情况,梯度完全关闭以通过网络回流。

这对于GAN尤其重要,因为生成器必须学习的唯一方法是从鉴别器接收梯度。

在这里插入图片描述
在这里插入图片描述
鉴别器从接收32x32x3图像张量开始。 与发生器相反,鉴别器执行一系列跨步的2次卷积。 每个都通过将特征向量的空间维度减小一半的大小来工作,也使得学习过滤器的数量加倍。

最后,鉴别器需要输出概率。 为此,我们在最终的logits上使用Logistic Sigmoid激活函数。

def discriminator(x, reuse=False, alpha=0.2, training=True):
    """
    Defines the discriminator network
    :param x: input for network
    :param reuse: Indicates whether or not the existing model variables should be used or recreated
    :param alpha: scalar for lrelu activation function
    :param training: Boolean for controlling the batch normalization statistics
    :return: A tuple of (sigmoid probabilities, logits)
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # Input layer is 32x32x?
        conv1 = conv2d(x, 64)
        conv1 = lrelu(conv1, alpha)

        conv2 = conv2d(conv1, 128)
        conv2 = batch_norm(conv2, training=training)
        conv2 = lrelu(conv2, alpha)

        conv3 = conv2d(conv2, 256)
        conv3 = batch_norm(conv3, training=training)
        conv3 = lrelu(conv3, alpha)

        # Flatten it
        flat = tf.reshape(conv3, (-1, 4*4*256))
        logits = dense(flat, 1)

        out = tf.sigmoid(logits)
        return out, logits

请注意,在此框架中,鉴别器充当常规二进制分类器。 一半时间它从训练集接收图像而另一半从发生器接收图像。

回到我们的冒险之旅,重现派对的门票,你所拥有的唯一信息来源是我们的朋友鲍勃的反馈。 换句话说,Bob在每次试用时提供给您的反馈质量对于完成工作至关重要。

以同样的方式,每当鉴别器注意到真实和假图像之间的差异时,它就向发生器发送信号。 该信号是从鉴别器向发生器向后流动的梯度。 通过接收它,发生器能够调整其参数以更接近真实的数据分布。

这是鉴别器的重要性。 事实上,发生器将与产生数据一样好,因为鉴别器正在分辨它们。

  • Losses

现在,让我们描述一下这种架构中最棘手的部分 - 损失。 首先,我们知道鉴别器从训练集和生成器接收图像。

我们希望鉴别器能够区分真实和假图像。 每次我们通过鉴别器运行一个小批量时,我们都会得到logits。 这些是模型中未缩放的值。

但是,我们可以将鉴别器接收的小批量分为两种类型。 第一个,仅由来自训练集和第二个的真实图像组成,仅包含假图像 - 由生成器创建的图像。

def model_loss(input_real, input_z, output_dim, alpha=0.2, smooth=0.1):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: random vector z
    :param out_channel_dim: The number of channels in the output image
    :param smooth: label smothing scalar
    :return: A tuple of (discriminator loss, generator loss)
    """
    g_model = generator(input_z, output_dim, alpha=alpha)
    d_model_real, d_logits_real = discriminator(input_real, alpha=alpha)

    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, alpha=alpha)

    # for the real images, we want them to be classified as positives,  
    # so we want their labels to be all ones.
    # notice here we use label smoothing for helping the discriminator to generalize better.
    # Label smoothing works by avoiding the classifier to make extreme predictions when extrapolating.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))

    # for the fake images produced by the generator, we want the discriminator to clissify them as false images,
    # so we set their labels to be all zeros.
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_model_fake)))

    # since the generator wants the discriminator to output 1s for its images, it uses the discriminator logits for the
    # fake images and assign labels of 1s to them.
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_model_fake)))

    d_loss = d_loss_real + d_loss_fake

    return d_loss, g_loss

由于两个网络同时训练,GAN还需要两个优化器。 每一个都分别用于最小化鉴别器和发生器的损耗函数。

我们希望鉴别器为真实图像输出接近1的概率,对于假图像输出接近0的概率。 为此,鉴别器需要两次损失。 因此,鉴别器的总损失是这两个部分损失的总和。 一个用于最大化真实图像的概率,另一个用于最小化伪图像的概率。

在这里插入图片描述
比较真实(左)和生成(右)SVHN样本图像。 虽然有些图像看起来很模糊而有些图像难以识别,但很明显数据分布是由模型捕获的。

在训练开始时,会出现两种有趣的情况。 首先,生成器不知道如何创建类似于训练集中的图像的图像。 第二,鉴别器不知道如何将它接收的图像分类为真实或假的。

结果,鉴别器接收两种非常不同类型的批次。 一个由训练集的真实图像组成,另一个包含非常嘈杂的信号。 随着训练的进行,发生器开始输出看起来更接近训练集图像的图像。 发生这种情况,因为发电机训练以学习组成训练集图像的数据分布。

与此同时,鉴别者开始真正善于将样本分类为真实或假冒。 因此,两种类型的小批量开始在结构上看起来彼此相似。 结果,这使得鉴别器无法将图像识别为真实或假的。

对于损失,我们使用vanilla交叉熵与Adam作为优化器的良好选择。

在这里插入图片描述
比较实际(左)和生成(右)MNIST样本图像。 由于MNIST图像具有更简单的数据结构,因此与SVHN相比,该模型能够生成更逼真的样本。

  • Concluding

GAN是目前机器学习中最热门的课程之一。 这些模型有可能解开无监督的学习方法,将ML扩展到新的视野。

自创建以来,研究人员一直在开发许多训练GAN的技术。 在用于训练GAN的改进技术中,作者描述了用于图像生成和半监督学习的最先进技术。

猜你喜欢

转载自blog.csdn.net/weixin_41697507/article/details/87895537