AIGC Practical Combat—WGAN (Wasserstein GAN)

0. Preface

The original generative adversarial network (Generative Adversarial Network, GAN) suffers from problems such as mode collapse and vanishing gradients during training. To address these problems, researchers have proposed a number of key techniques that improve the overall stability of the GAN model and reduce the likelihood of the issues above, for example WGAN (Wasserstein GAN) and WGAN-GP (Wasserstein GAN with Gradient Penalty). By making only slight adjustments to the original GAN framework, these methods make it possible to train complex GANs. In this section, we study WGAN and WGAN-GP, both of which make subtle adjustments to the original GAN framework to improve the stability of training and the quality of the generated images.

1. WGAN-GP

WGAN (Wasserstein GAN) is a major step forward in improving the stability of GAN training. With only a few simple changes, a GAN can achieve the following two properties:

  • A loss metric that correlates with the generator's convergence and the quality of the generated samples
  • A more stable optimization process

Specifically, WGAN introduces a new loss function (the Wasserstein loss) for both the discriminator and the generator; replacing binary cross-entropy with this loss makes GAN convergence more stable.
In this section, we will build a WGAN-GP (Wasserstein GAN with Gradient Penalty) and train it on the CelebA dataset to generate face images.

1.1 Wasserstein losses

First, let's review binary cross-entropy, the loss function used when training the DCGAN discriminator and generator:
$$-\frac{1}{n}\sum_{i=1}^n \left( y_i \log(p_i) + (1 - y_i)\log(1 - p_i) \right)$$
To train the GAN discriminator $D$, we compute the loss from two terms: the error between the prediction for real images, $p_i = D(x_i)$, and the label $y_i = 1$, and the error between the prediction for generated images, $p_i = D(G(z_i))$, and the label $y_i = 0$. Therefore, for the GAN discriminator, minimizing the loss function can be expressed as:
$$\min_D -\left( \mathbb{E}_{x \sim p_X}[\log D(x)] + \mathbb{E}_{z \sim p_Z}[\log(1 - D(G(z)))] \right)$$
To train the GAN generator $G$, we compute the loss from the error between the discriminator's prediction for the generated images, $p_i = D(G(z_i))$, and the label $y_i = 1$. Therefore, for the GAN generator, minimizing the loss function can be expressed as:
$$\min_G -\left( \mathbb{E}_{z \sim p_Z}[\log D(G(z))] \right)$$
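As a concrete reference, these two objectives can be written with Keras roughly as follows (a minimal sketch, not the implementation used later in this series; `real_preds` and `fake_preds` are assumed to be the discriminator's Sigmoid outputs for real and generated images):

    import tensorflow as tf

    bce = tf.keras.losses.BinaryCrossentropy()

    def discriminator_bce_loss(real_preds, fake_preds):
        # Real images should be classified as 1, generated images as 0
        real_loss = bce(tf.ones_like(real_preds), real_preds)
        fake_loss = bce(tf.zeros_like(fake_preds), fake_preds)
        return real_loss + fake_loss

    def generator_bce_loss(fake_preds):
        # The generator wants its images to be classified as 1 (real)
        return bce(tf.ones_like(fake_preds), fake_preds)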
Next, we compare the above loss function with the Wasserstein loss function.
The Wasserstein loss is the loss function used by the Wasserstein GAN (WGAN). Unlike the traditional binary cross-entropy loss, the Wasserstein loss uses the labels 1 and -1 and turns the discriminator's output from a probability into a score; for this reason the WGAN discriminator is often called a critic, and it is required to be a 1-Lipschitz continuous function. Specifically, the Wasserstein loss uses the labels $y_i = 1$ and $y_i = -1$ in place of $y_i = 1$ and $y_i = 0$, and removes the Sigmoid activation from the last layer of the discriminator, so the prediction $p_i$ is no longer confined to the range $[0, 1]$ but can take any value in $(-\infty, \infty)$. The Wasserstein loss is defined as follows:
$$-\frac{1}{n}\sum_{i=1}^n y_i p_i$$
When training the WGAN discriminator $D$, we compute the loss from the error between the prediction for real images, $p_i = D(x_i)$, and the label $y_i = 1$, and the error between the prediction for generated images, $p_i = D(G(z_i))$, and the label $y_i = -1$. Therefore, for the WGAN discriminator, minimizing the loss function can be expressed as:
$$\min_D -\left( \mathbb{E}_{x \sim p_X}[D(x)] - \mathbb{E}_{z \sim p_Z}[D(G(z))] \right)$$
In other words, the WGAN discriminator tries to maximize the difference between its predictions for real images and its predictions for generated images, with real images scoring higher.
For training the WGAN generator $G$, we compute the loss from the error between the discriminator's prediction for the generated images, $p_i = D(G(z_i))$, and the label $y_i = 1$. Therefore, for the WGAN generator, minimizing the loss function can be expressed as:
$$\min_G -\left( \mathbb{E}_{z \sim p_Z}[D(G(z))] \right)$$
In other words, the WGAN generator tries to produce images that the discriminator scores as highly as possible (i.e., images the discriminator believes to be real).
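The two objectives above translate directly into code. The following is a minimal sketch (the helper names are illustrative and the actual training step is implemented later in this section; `real_preds` and `fake_preds` are assumed to be the critic's unbounded scores for real and generated images):

    import tensorflow as tf

    def critic_wasserstein_loss(real_preds, fake_preds):
        # The critic maximizes D(x) - D(G(z)), i.e. minimizes its negative
        return tf.reduce_mean(fake_preds) - tf.reduce_mean(real_preds)

    def generator_wasserstein_loss(fake_preds):
        # The generator maximizes D(G(z)), i.e. minimizes -D(G(z))
        return -tf.reduce_mean(fake_preds)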

1.2 Lipschitz constraint

Since we allow the discriminator to output any value in $(-\infty, \infty)$ instead of restricting it to $[0, 1]$ with a Sigmoid function, the Wasserstein loss can be very large. Therefore, for the Wasserstein loss to work properly, an additional constraint must be placed on the discriminator: the 1-Lipschitz continuity constraint. The discriminator is a function $D$ that maps an image to a prediction; it is 1-Lipschitz continuous if, for any two input images $x_1$ and $x_2$, it satisfies the following inequality:
$$\frac{|D(x_1) - D(x_2)|}{|x_1 - x_2|} \leq 1$$
Here, $|x_1 - x_2|$ is the average absolute pixel-wise difference between the two images, and $|D(x_1) - D(x_2)|$ is the absolute difference between the discriminator's predictions. This means the rate of change of the discriminator's prediction is bounded everywhere (i.e., the absolute value of its gradient can never exceed 1). This can be seen in the 1-Lipschitz one-dimensional function in the figure below: no matter where the cone is placed along the curve, the curve never enters the cone. In other words, the rate at which any point on the curve can rise or fall is finite.

Lipschitz Continuous
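To build intuition for this inequality, the following toy sketch (illustrative only, not part of the WGAN-GP model) numerically estimates the largest ratio $|f(x_1) - f(x_2)| / |x_1 - x_2|$ of a one-dimensional function over random point pairs:

    import numpy as np

    def max_lipschitz_ratio(f, n_pairs=100000, low=-5.0, high=5.0):
        # Estimate max |f(x1) - f(x2)| / |x1 - x2| over random point pairs
        x1 = np.random.uniform(low, high, n_pairs)
        x2 = np.random.uniform(low, high, n_pairs)
        return np.max(np.abs(f(x1) - f(x2)) / (np.abs(x1 - x2) + 1e-12))

    print(max_lipschitz_ratio(np.sin))            # close to (and never above) 1: sin is 1-Lipschitz
    print(max_lipschitz_ratio(lambda x: 3.0 * x)) # close to 3: 3x violates the 1-Lipschitz constraint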

1.3 Enforcing Lipschitz constraints

In the original WGAN paper, the authors enforce the Lipschitz constraint by clipping the weights of the discriminator to a small range, $[-0.01, 0.01]$, after each training batch.
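For reference, such weight clipping could be implemented in Keras roughly as follows (a sketch of the original WGAN approach, assuming `critic` is the discriminator model; this is not the method used for WGAN-GP below):

    import tensorflow as tf

    def clip_critic_weights(critic, clip_value=0.01):
        # Clip every trainable weight of the critic to [-clip_value, clip_value]
        # after each discriminator update
        for w in critic.trainable_weights:
            w.assign(tf.clip_by_value(w, -clip_value, clip_value))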
However, because weight clipping greatly reduces the discriminator's capacity to learn, it is not an ideal way to enforce the Lipschitz constraint. A strong discriminator is crucial to WGAN's success: without accurate gradients, the generator cannot learn how to adjust its weights to produce better samples.
Therefore, researchers have proposed other methods of enforcing the Lipschitz constraint that improve WGAN's ability to learn complex features. One such method is the Wasserstein GAN with Gradient Penalty (WGAN-GP), which enforces the constraint directly by including a gradient penalty term in the discriminator's loss function: if the gradient norm deviates from 1, this term penalizes the model, which makes the training process more stable. Next, we add this gradient penalty term to the discriminator's loss function.

1.4 Gradient penalty loss

The following figure shows the training process of the WGAN-GP discriminator. Compared with the original discriminator's training process, the key improvement is that the gradient penalty loss is included as part of the overall loss function, alongside the Wasserstein losses computed on real and generated images.

WGAN-GP

The gradient penalty loss measures the squared difference between 1 and the norm of the gradient of the prediction with respect to the input image. The model naturally tends to find weights that minimize the gradient penalty, which encourages it to comply with the Lipschitz constraint.
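Written out, the gradient penalty term (as defined in the original WGAN-GP paper, with $\hat{x}$ denoting an interpolated image and $\lambda$ the penalty weight) is:

$$\lambda \, \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}\left[ \left( \|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1 \right)^2 \right]$$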
During training, it is impractical to evaluate the gradient at every point, so WGAN-GP evaluates it only at a small number of points. To ensure a balanced sample, we use a set of interpolated images, generated by interpolating pixel by pixel at random positions along the lines between the real images and the fake images.

interpolated image

The gradient penalty can be computed with Keras as follows:

    def gradient_penalty(self, batch_size, real_images, fake_images):
        # Each image in the batch gets a random number between 0 and 1, stored in the vector alpha
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        # Compute a set of interpolated images
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # Score each interpolated image with the critic
            pred = self.critic(interpolated, training=True)
        # Compute the gradient of the predictions (pred) with respect to the interpolated images
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # Compute the L2 norm (Euclidean length) of this gradient vector
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        # Return the mean of the squared difference between the L2 norm and 1
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

1.5 Training WGAN-GP

One advantage of using the Wasserstein loss is that we no longer need to worry about balancing the training of the discriminator and the generator. In fact, when using the Wasserstein loss, the discriminator should be trained to convergence before each generator update, to ensure that the gradients for the generator update are accurate. This is in contrast to a standard GAN, where it is important not to let the discriminator become too strong. With a Wasserstein GAN, we can therefore simply train the discriminator several times between generator updates to ensure it is close to convergence; typically, the discriminator is updated three to five times for each generator update. Having covered the two key concepts of WGAN-GP (the Wasserstein loss and the gradient penalty term), we can now implement WGAN-GP using Keras:

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        # Update the critic several times (critic_steps) for each generator update
        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(
                    random_latent_vectors, training=True
                )
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)
                # Compute the critic's Wasserstein loss
                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                # Compute the gradient penalty term
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # The critic loss is a weighted sum of the Wasserstein loss and the gradient penalty
                c_loss = c_wass_loss + c_gp * self.gp_weight
            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            # Update the critic's weights
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            # Compute the generator's Wasserstein loss
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the generator's weights
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}
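The `train_step` above refers to optimizer and metric attributes on `self`. In a full implementation these would typically be set up in an overridden `compile()` along the lines of the following sketch (an assumption about the surrounding `keras.Model` subclass, not code taken from the original implementation); the model can then be compiled with two Adam optimizers and trained with the standard `fit()` loop:

    def compile(self, c_optimizer, g_optimizer):
        # Store the critic and generator optimizers and the metrics used in train_step
        super().compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        self.c_loss_metric = tf.keras.metrics.Mean(name="c_loss")
        self.c_wass_loss_metric = tf.keras.metrics.Mean(name="c_wass_loss")
        self.c_gp_metric = tf.keras.metrics.Mean(name="c_gp")
        self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Expose the metrics so Keras resets them at the start of each epoch
        return [
            self.c_loss_metric,
            self.c_wass_loss_metric,
            self.c_gp_metric,
            self.g_loss_metric,
        ]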

One final point to note before training WGAN-GP is that the discriminator should not use batch normalization. This is because batch normalization creates correlations between images in the same batch, which makes the gradient penalty loss less effective. Experiments show that even without batch normalization in the discriminator, WGAN-GP can still produce excellent results.
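To illustrate this point, a WGAN-GP critic can simply omit normalization layers altogether. The following is a minimal sketch of such a critic (assuming 64×64×3 inputs; it is not necessarily the exact architecture used in this series), with an unbounded linear output and no batch normalization:

    from tensorflow import keras
    from tensorflow.keras import layers

    def build_critic(image_size=64, channels=3):
        critic_input = layers.Input(shape=(image_size, image_size, channels))
        x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(critic_input)
        x = layers.LeakyReLU(0.2)(x)
        # Note: no BatchNormalization layers anywhere in the critic
        x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(256, kernel_size=4, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Flatten()(x)
        # Single linear unit: an unbounded score, with no Sigmoid activation
        critic_output = layers.Dense(1)(x)
        return keras.Model(critic_input, critic_output, name="critic")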

2. Key differences between GAN and WGAN-GP

In summary, the key differences between a standard GAN and WGAN-GP are as follows:

  • WGAN-GP uses the Wasserstein loss
  • WGAN-GP uses the label 1 for real images and -1 for fake images
  • The last layer of the discriminator does not use a Sigmoid activation
  • A gradient penalty term is included in the discriminator's loss function
  • The discriminator is trained multiple times for each update of the generator's weights
  • There is no batch normalization layer in the discriminator

3. WGAN-GP model analysis

After training for 25 epochs, the generator of the WGAN-GP model is able to generate reasonable images:

Face generation results

The model has learned important high-level features of the face, with no signs of mode collapse.
If we compare the output of WGAN-GP with that of a variational autoencoder (Variational Autoencoder, VAE), we can see that the images produced by WGAN-GP are generally sharper. Overall, a VAE tends to produce images with blurry color boundaries, while a GAN produces more sharply defined images. On the other hand, GANs are generally harder to train than VAEs and take longer to reach satisfactory image quality.

Summary

In this section, we learned how the Wasserstein loss function addresses classic GAN training problems such as mode collapse and vanishing gradients, making GAN training more predictable and reliable. WGAN-GP imposes the 1-Lipschitz constraint on the training process by adding a term to the loss function that pulls the gradient norm toward 1.

Series link

AIGC Practical Combat—Introduction to Generative Models
AIGC Practical Combat—Deep Learning (DL)
AIGC Practical Combat—Convolutional Neural Network (Convolutional Neural Network, CNN)
AIGC Practical Combat—Autoencoder (Autoencoder)
AIGC Practical Combat—Variational Autoencoder (VAE)
AIGC Practical Combat—Using variational autoencoders to generate facial images
AIGC Practical Combat—Generative Adversarial Network (GAN)

Origin blog.csdn.net/LOVEmy134611/article/details/133974577