Demystified: Wasserstein GAN with Gradient Penalty (WGAN-GP)

1. Description

What is gradient penalty? Why is it better than weight clipping? How is gradient penalty implemented? The concept of the Wasserstein distance is unavoidable when discussing generative adversarial networks (GANs). This article is part of a series that demystifies and discusses important concepts surrounding Wasserstein GAN models.

Figure 1: (left) With weight clipping, the gradient norm either explodes or vanishes; with GP it does not. (right) Unlike GP, weight clipping pushes the weights to the two extreme clipping values.

2. Background information

In this article, we study the Wasserstein GAN with gradient penalty. Although the original Wasserstein GAN [2] improved training stability, it still sometimes generated poor samples or failed to converge. To review, the cost function of WGAN is:

$$\min_G \max_f \; \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[f(\tilde{x})]$$

Formula 1: WGAN value function.

where f is 1-Lipschitz continuous. The problems with WGAN stem primarily from the weight clipping method used to enforce Lipschitz continuity on the critic. WGAN-GP replaces weight clipping with a constraint on the critic's gradient norm to enforce Lipschitz continuity. This allows more stable network training than WGAN and requires less hyperparameter tuning. WGAN-GP, and therefore this article, builds on the Wasserstein GAN, which was discussed in the previous article in this series.
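For reference, the weight clipping that WGAN-GP replaces is usually implemented as a hard clamp on the critic's parameters after each optimizer step. Below is a minimal sketch; the helper name is ours, and c = 0.01 is the default clipping value from [2].

import torch

def clip_critic_weights(netD: torch.nn.Module, c: float = 0.01) -> None:
    # Clamp every critic parameter into [-c, c] after each update,
    # a crude way to bound the Lipschitz constant.
    with torch.no_grad():
        for p in netD.parameters():
            p.clamp_(-c, c)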

Proposition 1

The optimal 1-Lipschitz critic f* of Equation 1 is differentiable and has unit gradient norm almost everywhere under Pr and Pg.

Pr and Pg are the real and generated distributions, respectively. The proof of Proposition 1 can be found in [1].
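Written out, Proposition 1 states that the optimal critic f* satisfies

$$\|\nabla_x f^*(x)\|_2 = 1 \quad \text{almost everywhere under } \mathbb{P}_r \text{ and } \mathbb{P}_g.$$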

3. Problems with weight clipping

3.1 Capacity underuse

Figure 2: Value surfaces learned by a WGAN critic trained with weight clipping (top) and with gradient penalty (bottom). Image source: [1]

Using weight clipping to enforce the k-Lipschitz constraint leads the critic to learn very simple functions.

From Proposition 1, we know that the gradient norm of the optimal critic is 1 almost everywhere under Pr and Pg. In the weight clipping setting, the critic tries to attain its maximum gradient norm k and ends up learning very simple functions.

Figure 2 illustrates this effect. The critic is trained to convergence with a fixed generator distribution (Pg), defined as the real distribution (Pr) plus unit-variance Gaussian noise. We can clearly see that the critic trained with weight clipping ends up learning simple functions and fails to capture the higher moments of the data distribution, while the critic trained with gradient penalty does not have this problem.

3.2 Exploding and vanishing gradients

The interaction between the weight constraint and the loss function makes WGAN training difficult and causes gradients to explode or vanish.

This can be seen clearly in Figure 1 (left), where the critic's gradient norm explodes or vanishes for different clipping values. Figure 1 (right) also shows that weight clipping pushes the critic's weights to the two extreme clipping values. Critics trained with gradient penalty suffer from neither problem.

4. Gradient penalty

The idea of the gradient penalty is to enforce a constraint such that the gradients of the critic's output with respect to its inputs have unit norm (Proposition 1).

The authors propose a soft version of this constraint: a penalty on the gradient norm for samples x̂ ∼ P_x̂. The new objective is

$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[f(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\!\left[\left(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1\right)^2\right]$$

Formula 2: Critic loss function.

In Equation 2, the term on the left side of the sum is the original critic loss, and the term on the right side of the sum is the gradient penalty.

P_x̂ is the distribution obtained by sampling uniformly along straight lines between pairs of points drawn from the real distribution Pr and the generated distribution Pg. This is done because the optimal critic has unit gradient norm along straight lines connecting coupled samples from Pr and Pg.
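Concretely, each penalty sample x̂ is a random interpolation between a coupled real and generated sample, which is exactly what the code in Section 5.1 implements:

$$\hat{x} = \epsilon x + (1 - \epsilon)\,\tilde{x}, \qquad x \sim \mathbb{P}_r,\; \tilde{x} \sim \mathbb{P}_g,\; \epsilon \sim U[0, 1].$$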

λ, the penalty coefficient, weights the gradient penalty term. In the paper, the authors set λ = 10 for all experiments.

Batch normalization is no longer used in the critic, because batch norm maps a batch of inputs to a batch of outputs, whereas here we want the gradient of each output with respect to its own input. (The authors of [1] recommend layer normalization as a drop-in replacement.)
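For illustration, a critic without batch norm might look like the sketch below. The use of layer normalization follows the recommendation in [1], but the architecture and layer sizes here are assumptions for a 64x64 RGB input, not the paper's exact model.

import torch.nn as nn

# Hypothetical 64x64 RGB critic: layer normalization acts per sample,
# so each critic score depends only on its own input.
critic = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1),     # 64x64 -> 32x32
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),   # 32x32 -> 16x16
    nn.LayerNorm([128, 16, 16]),
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 16x16 -> 8x8
    nn.LayerNorm([256, 8, 8]),
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 1, 8),                         # 8x8 -> scalar score
    nn.Flatten(),
)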

5. Code examples

5.1 Gradient penalty

The implementation of the gradient penalty in PyTorch is shown below.

import torch
from torch import autograd

def compute_gp(netD, real_data, fake_data):
    batch_size = real_data.size(0)
    # Sample epsilon from a uniform distribution, one value per sample
    eps = torch.rand(batch_size, 1, 1, 1).to(real_data.device)
    eps = eps.expand_as(real_data)

    # Interpolate between real data and fake data
    interpolation = eps * real_data + (1 - eps) * fake_data
    # Make sure gradients can be taken w.r.t. the interpolated input
    # (needed when real_data and fake_data are both detached)
    if not interpolation.requires_grad:
        interpolation.requires_grad_(True)

    # Get critic scores for the interpolated images
    interp_logits = netD(interpolation)
    grad_outputs = torch.ones_like(interp_logits)

    # Compute gradients of the critic output w.r.t. the interpolation
    gradients = autograd.grad(
        outputs=interp_logits,
        inputs=interpolation,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # Compute the per-sample gradient norm and return the penalty
    gradients = gradients.view(batch_size, -1)
    grad_norm = gradients.norm(2, dim=1)
    return torch.mean((grad_norm - 1) ** 2)
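Below is a minimal sketch of how compute_gp might slot into a single critic update. The names netD, netG, optD and the critic_step helper are assumptions for this sketch rather than code from the repository, but λ = 10 follows the paper [1].

lambda_gp = 10  # penalty coefficient, set to 10 in [1]

def critic_step(netD, netG, optD, real_data, z):
    optD.zero_grad()
    # Detach the generator output: the critic update should not
    # propagate gradients into the generator.
    fake_data = netG(z).detach()

    # Original critic loss (left term of Equation 2) plus the gradient penalty
    loss_d = netD(fake_data).mean() - netD(real_data).mean()
    loss_d = loss_d + lambda_gp * compute_gp(netD, real_data, fake_data)

    loss_d.backward()
    optD.step()
    return loss_d.item()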

5.2 WGAN-GP training code

The full code for training the WGAN-GP model can be found in the gan-zoo-pytorch repository [3].

5.3 Output

Figure 3: Images generated by the WGAN-GP model.

Figure 3 shows some early results from training WGAN-GP. Training was stopped once it was confirmed that the model was training as expected; the model was not trained to convergence.

6. Conclusion

Wasserstein GAN provides much-needed stability in training generative adversarial networks. However, enforcing the Lipschitz constraint via weight clipping leads to problems such as exploding and vanishing gradients. The gradient penalty does not suffer from these issues, allowing easier optimization and convergence compared to the original WGAN. This article examined these issues, introduced the gradient penalty constraint, and showed how to implement the gradient penalty in PyTorch. Finally, the code for training the WGAN-GP model was provided, along with some early-stage outputs.

Aadhithya Sankar

7. References

[1] Gulrajani, Ishaan, et al. “Improved Training of Wasserstein GANs.” arXiv preprint arXiv:1704.00028 (2017).

[2] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. “Wasserstein Generative Adversarial Networks.” International Conference on Machine Learning. PMLR, 2017.

[3] GitHub - aadhithya/gan-zoo-pytorch: A zoo of GAN implementations. https://github.com/aadhithya/gan-zoo-pytorch
