AIGC in action - WGAN
0. Preface
The original generative adversarial network (GAN) suffers from problems such as mode collapse and vanishing gradients during training. To address them, researchers have proposed a number of key techniques that improve the overall stability of GAN models and make these problems less likely. Examples include WGAN (Wasserstein GAN) and WGAN-GP (Wasserstein GAN with Gradient Penalty), which make small adjustments to the original GAN framework so that complex GANs can be trained. In this section, we study WGAN and WGAN-GP and see how these adjustments improve the stability and quality of image generation.
1. WGAN-GP
WGAN (Wasserstein GAN) was a major step forward in stabilizing GAN training. With a few simple changes, a GAN gains the following two properties:
- A loss metric that correlates with the convergence of the generator and the quality of the generated samples
- Improved stability of the optimization process
Specifically, WGAN proposes a new loss function (the Wasserstein loss) for both the discriminator and the generator. Using it in place of binary cross-entropy makes GAN convergence more stable.
In this section, we will build a WGAN-GP (Wasserstein GAN with Gradient Penalty) and train it on the CelebA dataset to generate face images.
1.1 Wasserstein loss
First, let's review the binary cross-entropy loss. This is the loss function used when training the DCGAN discriminator and generator:
$$-\frac{1}{n}\sum_{i=1}^{n}\bigl(y_i \log(p_i) + (1 - y_i)\log(1 - p_i)\bigr)$$
To train the GAN discriminator $D$, we compute the loss from two parts: the error between the predictions on real images $p_i = D(x_i)$ and the label $y_i = 1$, and the error between the predictions on generated images $p_i = D(G(z_i))$ and the label $y_i = 0$. For the GAN discriminator, minimizing the loss function can therefore be written as:
$$\min_{D} -\left(\mathbb{E}_{x\sim p_X}[\log D(x)] + \mathbb{E}_{z\sim p_Z}[\log(1 - D(G(z)))]\right)$$
To train the GAN generator $G$, we compute the loss from the error between the discriminator's predictions on generated images $p_i = D(G(z_i))$ and the label $y_i = 1$. For the GAN generator, minimizing the loss function can therefore be written as:
$$\min_{G} -\left(\mathbb{E}_{z\sim p_Z}[\log D(G(z))]\right)$$
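The two binary cross-entropy objectives above can be sketched in plain Python. This is only an illustration: the discriminator outputs `d_real` and `d_fake` are made-up probabilities, the helper name `binary_cross_entropy` is an assumption, and the discriminator loss is averaged over the combined batch of real and generated images.

```python
import math

def binary_cross_entropy(labels, predictions, eps=1e-12):
    """Mean binary cross-entropy for labels in {0, 1} and probabilities in (0, 1)."""
    total = 0.0
    for y, p in zip(labels, predictions):
        p = min(max(p, eps), 1.0 - eps)  # guard against log(0)
        total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return -total / len(labels)

# Hypothetical discriminator outputs, for illustration only.
d_real = [0.9, 0.8]  # D(x_i) on real images
d_fake = [0.2, 0.1]  # D(G(z_i)) on generated images

# Discriminator: real images are labeled 1, generated images are labeled 0.
d_loss = binary_cross_entropy([1, 1] + [0, 0], d_real + d_fake)

# Generator: generated images are labeled 1, because it wants D to call them real.
g_loss = binary_cross_entropy([1, 1], d_fake)
```

With these numbers the discriminator is doing well, so its loss is small while the generator's loss is large.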
Next, we compare the above loss functions with the Wasserstein loss function.
The Wasserstein loss is the loss function used by Wasserstein GAN (WGAN). Unlike the traditional binary cross-entropy loss, the Wasserstein loss uses the labels 1 and -1 and turns the discriminator's output from a probability into a score. For this reason, the WGAN discriminator is often called a critic, and it is required to be a 1-Lipschitz continuous function. Specifically, the Wasserstein loss uses the labels $y_i = 1$ and $y_i = -1$ in place of $y_i = 1$ and $y_i = 0$, and the Sigmoid activation of the discriminator's last layer is removed, so that the predictions $p_i$ are no longer confined to $[0, 1]$ but can take any value in $(-\infty, \infty)$. The Wasserstein loss is defined as follows:
$$-\frac{1}{n}\sum_{i=1}^{n} y_i p_i$$
When training the WGAN discriminator $D$, we compute the loss from two parts: the error between the predictions on real images $p_i = D(x_i)$ and the label $y_i = 1$, and the error between the predictions on generated images $p_i = D(G(z_i))$ and the label $y_i = -1$. For the WGAN discriminator, minimizing the loss function can therefore be written as:
$$\min_{D} -\left(\mathbb{E}_{x\sim p_X}[D(x)] - \mathbb{E}_{z\sim p_Z}[D(G(z))]\right)$$
In other words, the WGAN discriminator tries to maximize the difference between its predictions on real images and its predictions on generated images, with real images scoring higher.
To train the WGAN generator $G$, we compute the loss from the discriminator's predictions on generated images $p_i = D(G(z_i))$ and the label $y_i = 1$. For the WGAN generator, minimizing the loss function can therefore be written as:
$$\min_{G} -\left(\mathbb{E}_{z\sim p_Z}[D(G(z))]\right)$$
In other words, the WGAN generator tries to produce images that the discriminator scores highly (that is, images the discriminator believes are real).
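The Wasserstein objectives can be sketched the same way in plain Python. The critic scores `c_real` and `c_fake` are made-up values and the helper name `wasserstein_loss` is an assumption, so treat this as an illustration of the formula rather than a reference implementation.

```python
def wasserstein_loss(labels, scores):
    """Mean of -y_i * p_i, with labels in {1, -1} and unbounded critic scores."""
    return -sum(y * p for y, p in zip(labels, scores)) / len(labels)

# Hypothetical critic scores (no sigmoid, so any real number is allowed).
c_real = [2.5, 1.5]    # D(x_i) on real images
c_fake = [-1.0, -3.0]  # D(G(z_i)) on generated images

# Critic: real images are labeled 1, generated images are labeled -1.
# This equals mean(c_fake) - mean(c_real).
critic_loss = wasserstein_loss([1, 1], c_real) + wasserstein_loss([-1, -1], c_fake)

# Generator: generated images are labeled 1, so it wants high critic scores.
generator_loss = wasserstein_loss([1, 1], c_fake)
```

The critic loss becomes more negative as the gap between real and generated scores widens, which is exactly what the critic is trying to achieve.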
1.2 Lipschitz constraint
Since we allow the discriminator to output any value in $(-\infty, \infty)$ rather than limiting the output to $[0, 1]$ as the Sigmoid function does, the Wasserstein loss can become very large. For the Wasserstein loss to work properly, an additional constraint must therefore be placed on the discriminator: 1-Lipschitz continuity. The discriminator is a function $D$ that maps an image to a prediction. It is 1-Lipschitz continuous if, for any two input images $x_1$ and $x_2$, it satisfies the following inequality:
$$\frac{|D(x_1) - D(x_2)|}{|x_1 - x_2|} \le 1$$
Here, $|x_1 - x_2|$ is the average pixel-wise absolute difference between the two images, and $|D(x_1) - D(x_2)|$ is the absolute difference between the discriminator's predictions. This means that the rate at which the discriminator's predictions can change is bounded everywhere (that is, the absolute value of the gradient can never exceed 1). This can be seen in the Lipschitz continuous one-dimensional function in the figure below: wherever the cone is placed, the curve never enters the cone. In other words, the rate at which the curve can rise or fall at any point is bounded.
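For a one-dimensional function, the inequality can be checked numerically on a grid of sample points. The sketch below is an illustration under assumptions: the helper `max_slope` and the two test functions are chosen here for demonstration, and a finite grid can only suggest, not prove, that a function is 1-Lipschitz.

```python
import math

def max_slope(f, xs):
    """Largest |f(a) - f(b)| / |a - b| over all distinct pairs of sample points."""
    worst = 0.0
    for i, a in enumerate(xs):
        for b in xs[i + 1:]:
            worst = max(worst, abs(f(a) - f(b)) / abs(a - b))
    return worst

xs = [i * 0.1 for i in range(-50, 51)]  # sample points in [-5, 5]

sin_slope = max_slope(math.sin, xs)             # sin is 1-Lipschitz
steep_slope = max_slope(lambda x: 3.0 * x, xs)  # slope 3 violates the bound
```

Here `sin_slope` stays below 1, while `steep_slope` is about 3, so the second function would not be an admissible critic.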
1.3 Enforcing the Lipschitz constraint
In the original WGAN paper, the authors enforce the Lipschitz constraint by clipping the discriminator's weights to a small range, $[-0.01, 0.01]$, after each training step.
Because the discriminator's weights are clipped, its capacity to learn is greatly reduced. Weight clipping is therefore not an ideal way to enforce the Lipschitz constraint. A strong discriminator is crucial to the success of WGAN, because without accurate gradients the generator cannot learn how to adjust its weights to produce better samples.
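Weight clipping itself is a simple operation, sketched below in plain Python. The function name `clip_weights` and the example weights are assumptions made for illustration; in a real model the same clamp would be applied to every weight tensor of the critic.

```python
def clip_weights(weights, clip_value=0.01):
    """Clamp every weight into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, w)) for w in weights]

# Hypothetical critic weights after a gradient update.
weights = [0.5, -0.003, 0.02, -1.2, 0.0]
clipped = clip_weights(weights)
```

Most of the example weights end up pinned at the boundary values, which hints at why clipping limits what the critic can represent.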
Researchers have therefore proposed many other methods to enforce the Lipschitz constraint and improve WGAN's ability to learn complex features. One such method is Wasserstein GAN with gradient penalty (WGAN-GP). It enforces the Lipschitz constraint directly by including a gradient penalty term in the discriminator's loss function: if the gradient norm deviates from 1, this term penalizes the model, which makes training more stable. Next, we add this gradient penalty term to the discriminator's loss function.
1.4 Gradient penalty loss
The following figure shows the training process of the WGAN-GP discriminator. Compared with the training process of the original discriminator, the key addition is the gradient penalty loss, which is included as part of the overall loss function alongside the Wasserstein losses from real and generated images.
The gradient penalty loss measures the squared difference between 1 and the norm of the gradient of the predictions with respect to the input images. The model tends to find weights that minimize the gradient penalty term, which encourages it to comply with the Lipschitz constraint.
Computing this gradient at every point during training is intractable, so WGAN-GP evaluates it at only a handful of points. To ensure a balanced mix, we use a set of interpolated images, each lying at a randomly chosen position between a real image and a fake image and blended pixel by pixel.
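Written out in symbols, the interpolation and the penalty look roughly as follows. The symbols $\hat{x}$, $\alpha$ and $\lambda$ (the penalty weight) are notation assumed here for illustration; the discriminator loss matches the WGAN discriminator objective given earlier, with the penalty added.

```latex
\hat{x} = x + \alpha\,\bigl(G(z) - x\bigr), \qquad \alpha \sim U[0, 1]

\mathrm{GP} = \mathbb{E}_{\hat{x}}\Bigl[\bigl(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\bigr)^2\Bigr]

L_D = \mathbb{E}_{z\sim p_Z}[D(G(z))] - \mathbb{E}_{x\sim p_X}[D(x)] + \lambda\,\mathrm{GP}
```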
The gradient penalty can be computed with Keras as follows:
import tensorflow as tf

# Method of the WGAN-GP model class (it relies on self.critic)
def gradient_penalty(self, batch_size, real_images, fake_images):
    # Each image in the batch gets one random number between 0 and 1, stored in alpha
    alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    # Compute a set of interpolated images
    diff = fake_images - real_images
    interpolated = real_images + alpha * diff
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        # Score each interpolated image with the critic
        pred = self.critic(interpolated, training=True)
    # Gradient of the predictions (pred) with respect to the inputs (interpolated)
    grads = gp_tape.gradient(pred, [interpolated])[0]
    # L2 norm (Euclidean length) of each per-image gradient
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # Return the mean squared difference between the L2 norm and 1
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp
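To see what this computation produces without TensorFlow, the following is a minimal plain-Python sketch. It swaps automatic differentiation for central finite differences, and it uses two toy linear critics whose gradient norms are known in advance; all names here (`numerical_gradient`, `unit_critic`, `steep_critic`) are assumptions made for illustration.

```python
import math
import random

def numerical_gradient(f, x, h=1e-5):
    """Central-difference estimate of the gradient of f at point x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        up = x[:i] + [x[i] + h] + x[i + 1:]
        down = x[:i] + [x[i] - h] + x[i + 1:]
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

def gradient_penalty(critic, real_batch, fake_batch, rng):
    """Mean squared deviation of the critic's gradient norm from 1 at interpolated points."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        alpha = rng.random()  # one uniform draw in [0, 1) per sample
        interpolated = [r_i + alpha * (f_i - r_i) for r_i, f_i in zip(real, fake)]
        grad = numerical_gradient(critic, interpolated)
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)

# Toy linear critics D(x) = w . x, whose gradient is w everywhere.
def unit_critic(x):
    return 0.6 * x[0] + 0.8 * x[1]  # gradient norm 1

def steep_critic(x):
    return 3.0 * x[0] + 4.0 * x[1]  # gradient norm 5

rng = random.Random(0)
real_batch = [[1.0, 2.0], [0.5, -1.0]]
fake_batch = [[-1.0, 0.0], [2.0, 1.0]]

gp_unit = gradient_penalty(unit_critic, real_batch, fake_batch, rng)
gp_steep = gradient_penalty(steep_critic, real_batch, fake_batch, rng)
```

The critic with gradient norm 1 incurs essentially no penalty, while the critic with gradient norm 5 is penalized by about $(5 - 1)^2 = 16$.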
1.5 Training WGAN-GP
One advantage of the Wasserstein loss is that we no longer need to worry about balancing the training of the discriminator and the generator. In fact, when using the Wasserstein loss, the discriminator must be trained to convergence before the generator is updated, so that the gradients for the generator update are accurate. This contrasts with a standard GAN, where it is important not to let the discriminator become too strong. With a Wasserstein GAN, we can therefore simply train the discriminator several times between generator updates to keep it close to convergence. A typical ratio is three to five discriminator updates per generator update. Having covered the two key concepts behind WGAN-GP (the Wasserstein loss and the gradient penalty term), we can now implement the WGAN-GP training step with Keras:
def train_step(self, real_images):
    batch_size = tf.shape(real_images)[0]
    # Update the critic self.critic_steps times (for example, three) per generator update
    for i in range(self.critic_steps):
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        with tf.GradientTape() as tape:
            fake_images = self.generator(
                random_latent_vectors, training=True
            )
            fake_predictions = self.critic(fake_images, training=True)
            real_predictions = self.critic(real_images, training=True)
            # Wasserstein loss of the critic
            c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
            # Gradient penalty term
            c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
            # The critic loss is a weighted sum of the Wasserstein loss and the gradient penalty
            c_loss = c_wass_loss + c_gp * self.gp_weight
        c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
        # Update the critic's weights
        self.c_optimizer.apply_gradients(
            zip(c_gradient, self.critic.trainable_variables)
        )
    random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
    with tf.GradientTape() as tape:
        fake_images = self.generator(random_latent_vectors, training=True)
        fake_predictions = self.critic(fake_images, training=True)
        # Wasserstein loss of the generator
        g_loss = -tf.reduce_mean(fake_predictions)
    gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
    # Update the generator's weights
    self.g_optimizer.apply_gradients(
        zip(gen_gradient, self.generator.trainable_variables)
    )
    # Track the losses from the last critic step and the generator step
    self.c_loss_metric.update_state(c_loss)
    self.c_wass_loss_metric.update_state(c_wass_loss)
    self.c_gp_metric.update_state(c_gp)
    self.g_loss_metric.update_state(g_loss)
    return {m.name: m.result() for m in self.metrics}
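The update schedule that this training step follows can be pictured with a tiny plain-Python sketch. The function `run_schedule` is hypothetical and performs no training; it only records the order in which critic and generator updates would happen for a given number of critic steps.

```python
def run_schedule(num_generator_updates, critic_steps=3):
    """Return the order of updates: critic_steps critic updates, then one generator update."""
    order = []
    for _ in range(num_generator_updates):
        order.extend(["critic"] * critic_steps)  # bring the critic close to convergence
        order.append("generator")                # then take one generator step
    return order

schedule = run_schedule(num_generator_updates=2, critic_steps=3)
```

With three critic steps, two generator updates are interleaved with six critic updates in total.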
One final point to note before training WGAN-GP is that the discriminator should not use batch normalization. Batch normalization creates correlations between images in the same batch, which makes the gradient penalty less effective. Experiments show that WGAN-GP still produces excellent results without batch normalization in the discriminator.
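The correlation introduced by batch normalization can be illustrated with scalars in plain Python. This is a simplified sketch: `batch_norm` below is a hand-written normalization without learned scale or shift, and the input values are made up.

```python
import math

def batch_norm(values, eps=1e-5):
    """Normalize a batch of scalars with the batch's own mean and variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

# The first sample is identical in both batches; only its batch-mates differ.
out_a = batch_norm([1.0, 2.0, 3.0])
out_b = batch_norm([1.0, 2.0, 30.0])

# A per-sample function gives the same output for the same input,
# regardless of what else is in the batch.
per_sample_a = [2.0 * v for v in [1.0, 2.0, 3.0]]
per_sample_b = [2.0 * v for v in [1.0, 2.0, 30.0]]
```

The normalized value of the first sample changes when the other samples change, so its output (and its gradient) is no longer a function of that sample alone, which is what the per-image gradient penalty assumes.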
2. Key differences between GAN and WGAN-GP
In summary, the key differences between a standard GAN and WGAN-GP are as follows:
- WGAN-GP uses the Wasserstein loss
- WGAN-GP uses the label 1 for real images and -1 for fake images
- The last layer of the discriminator has no sigmoid activation
- A gradient penalty term is included in the discriminator's loss function
- The discriminator is trained several times for each generator weight update
- The discriminator has no batch normalization layers
3. WGAN-GP model analysis
After 25 epochs of training, the generator of the WGAN-GP model is able to produce plausible images:
The model has learned the important high-level features of a face, with no sign of mode collapse.
Comparing the output of WGAN-GP with that of a variational autoencoder (VAE), we can see that the images produced by WGAN-GP are generally sharper. Overall, VAEs tend to produce images with blurred color boundaries, whereas GANs produce more clearly defined images. GANs are, however, generally harder to train than VAEs and take longer to reach satisfactory quality.
Summary
In this section, we learned how the Wasserstein loss function addresses classic GAN training problems such as mode collapse and vanishing gradients, making GAN training more predictable and reliable. WGAN-GP imposes the 1-Lipschitz constraint on the training process by adding a term to the loss function that pulls the gradient norm toward 1.
Series links
AIGC Practical Combat—Introduction to Generative Models
AIGC Practical Combat—Deep Learning (DL)
AIGC Practical Combat—Convolutional Neural Network (Convolutional Neural Network, CNN)
AIGC Practical Combat—Autoencoder (Autoencoder)
AIGC Practical Combat—Variational Autoencoder (VAE)
AIGC Practical Combat—Using variational autoencoders to generate facial images
AIGC Practical Combat—Generative Adversarial Network (GAN)