Training GANs on spatiotemporal data: a practical guide (Part 1/3)

Part 1: A closer look at the most notorious instabilities in GAN training.

1. Description

        GANs are by far the most popular deep generative models, mainly because of the incredible results they have recently produced on image generation tasks. However, GANs are not easy to train: their basic design introduces numerous instabilities. If you have ever tried to train a GAN on anything other than MNIST, you will quickly realize that all the talk about the pain of training them (and the research fields devoted to solving this problem) does not exaggerate the problem.

2. The instability of GAN

        We will systematically address the causes of these notorious instabilities, along with the solutions that we found to work empirically well in our experiments after extensively trying almost every trick in the book. This three-part series is a practical guide to training GANs, focusing on spatiotemporal data generation, and is structured as follows:

1.  Part 1: A closer look at the most notorious instabilities in GAN training.

2.  Part 2: Possible solutions to the common pitfalls discussed in Part 1.

3.  Part 3: Special cases of training GANs on spatiotemporal data: metrics to track, unique complexities, and their solutions.

        The instabilities and solutions discussed in this series are model- and use-case-agnostic, not limited to spatiotemporal settings, and are a good starting point for any GAN training exercise. In this article, we discuss why training GANs is so elusive by detailing the most notorious instabilities in GAN training. We will study a) how an imbalance between discriminator (D) and generator (G) training can lead to mode collapse and to silent learning failures due to vanishing gradients; b) the sensitivity of GANs to hyperparameters; and c) how the GAN loss can be misleading about model performance.

        [Note: We assume that the reader of this article has prerequisite knowledge of GAN basics and has some experience training GANs. For this reason, we will skip “What are GANs?” and refer readers to this article for a quick refresher.]

3. Why is training GANs so elusive?

        In this section, we detail some of the most notorious instabilities in GAN training, along with every solution that worked well in practice in our experiments. That said, it is recommended to run the first few iterations with vanilla settings to detect which of the following pitfalls are present in the architecture and task at hand. You can then iteratively apply these solutions (ordered according to their complexity and their effectiveness, based on our experience) to further stabilize training. Please note that these tips are intended only as a starting point for direction and are not an exhaustive list of one-shot fixes. Readers are advised to further explore their architecture and training dynamics for optimal results.

3.1. Imbalance between generator and discriminator

        It's easy to tell whether a painting is by Van Gogh, but very difficult to actually paint one. G's task is therefore considered harder than D's. At the same time, G's ability to learn to generate realistic output depends on how well D is trained: an optimal D gives G rich signals to learn from and improve its generations. It is therefore important to balance the training of G and D to obtain optimal learning conditions.

        GANs are based on a zero-sum, non-cooperative game and attempt to reach a Nash equilibrium. However, some cost functions are known not to converge under gradient descent, especially in non-convex games. This introduces a lot of instability into GAN training, because the G and D training steps become unbalanced in the min-max game, resulting in suboptimal learning gradients. These instabilities are discussed below:

        1. Vanishing gradients:

        Whether D should be better than G, or vice versa, for optimal GAN training is answered by considering the following:

a) If D becomes too good too fast, G's gradients vanish and it can never catch up.

b) On the other hand, if D is suboptimal, G can easily fool it with gibberish, since D's predictions are poor. This again yields no useful gradients to learn from, and hence no improvement in G's output.

        So, ideally, G and D should each periodically overtake the other. If you see either loss moving monotonically in one direction, your GAN training is most likely broken.
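To make the vanishing-gradient failure concrete, here is a small NumPy sketch (our own illustration, not code from the article) comparing the generator gradient under the saturating loss log(1 − D(G(z))) with the non-saturating alternative −log D(G(z)), where d = σ(a) is the discriminator's sigmoid output on a generated sample. When D confidently rejects fakes (a strongly negative, d → 0), the saturating gradient collapses to zero while the non-saturating one stays useful:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def saturating_grad(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a): vanishes as D gets confident (a -> -inf)
    return -sigmoid(a)

def non_saturating_grad(a):
    # d/da [-log sigmoid(a)] = sigmoid(a) - 1: stays near -1 as D gets confident
    return sigmoid(a) - 1.0

# a is D's pre-sigmoid score for a generated sample;
# strongly negative a means D confidently labels the sample fake.
for a in [0.0, -3.0, -6.0]:
    print(a, saturating_grad(a), non_saturating_grad(a))
```

At a = -6 the saturating gradient is about -0.0025 (almost no learning signal for G), while the non-saturating gradient is about -0.998, which is why an overpowered D stalls the original objective.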

        2.  Mode collapse: if G is trained disproportionately more, it converges to producing the same output repeatedly, thereby fooling D well without any incentive to care about sample diversity.
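One cheap way to catch this early (a hypothetical check we sketch here, not a method from the article) is to track the diversity of a generated batch, e.g., the mean pairwise distance between samples; a collapsing generator drives this statistic towards zero:

```python
import numpy as np

def batch_diversity(samples):
    """Mean pairwise Euclidean distance within a batch (rows = samples)."""
    n = len(samples)
    dists = [np.linalg.norm(samples[i] - samples[j])
             for i in range(n) for j in range(i + 1, n)]
    return float(np.mean(dists))

rng = np.random.default_rng(0)
healthy = rng.normal(size=(16, 64))                      # diverse batch
collapsed = np.tile(rng.normal(size=(1, 64)), (16, 1))   # one mode repeated

print(batch_diversity(healthy))    # clearly positive
print(batch_diversity(collapsed))  # ~ 0
```

Logging this on generated batches (or on their feature embeddings) per epoch gives an objective early-warning signal alongside the losses.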

3.2  Why gradients vanish

        The generator in the original GAN objective (Ian Goodfellow, 2014) effectively minimizes the Jensen-Shannon (JS) divergence between the real distribution p and the generated distribution q:

        JSD(p ‖ q) = ½ KL(p ‖ m) + ½ KL(q ‖ m), where m = (p + q)/2

        In this case, it is easy to see that the penalty is high both when the generator misses certain modes of the distribution (i.e., when p(x) > 0 but q(x) → 0) and when the generated data looks unrealistic (i.e., when p(x) → 0 but q(x) > 0). This forces the generator to produce higher-quality output while maintaining diversity.

        However, this formulation causes the generator's gradient to vanish when the discriminator reaches optimality. This is evident from the example below, where p and q are Gaussian and p has a mean of zero. The right-hand plot shows that the gradient of the JS divergence vanishes from q1 to q3. This causes the GAN generator to learn very slowly (or not at all) when the loss saturates in these regions. This situation manifests early in GAN training, when p and q are very different and D's task is easy because G's approximation is far from the actual distribution.
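The saturation is easy to reproduce numerically. The sketch below (our illustration; the grid and means are arbitrary) discretizes two unit-variance Gaussians p and q on a grid and computes JSD(p ‖ q) = ½ KL(p ‖ m) + ½ KL(q ‖ m) with m = (p + q)/2. As the means separate, the divergence flattens out at log 2, so its gradient with respect to the generator's parameters vanishes:

```python
import numpy as np

def discretized_gaussian(mu, sigma, grid):
    pdf = np.exp(-0.5 * ((grid - mu) / sigma) ** 2)
    return pdf / pdf.sum()  # normalize to a pmf on the grid

def kl(p, q):
    mask = p > 0  # KL terms with p(x) = 0 contribute nothing
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

grid = np.linspace(-30, 30, 4001)
p = discretized_gaussian(0.0, 1.0, grid)
for mu in [0.0, 1.0, 5.0, 10.0, 15.0]:
    q = discretized_gaussian(mu, 1.0, grid)
    print(mu, jsd(p, q))  # climbs, then flattens near log(2) ~ 0.693
```

Once the distributions barely overlap, moving q further changes the JSD by essentially nothing, which is exactly the "no gradient to learn from" regime described above.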

3.3   Understanding mode collapse

        Mode collapse is by far the most difficult and non-trivial problem in GAN training. Although there are many intuitive explanations for mode collapse, in practice our understanding of it is still very limited. A key intuition that has helped practitioners so far is that when G is trained disproportionately, without enough updates to D, the generator eventually converges to the single best image x* that most fools D, i.e., the most realistic image from the discriminator's perspective. At that point, x* becomes independent of z: for every z, the generator produces the same image.

        Eventually, D (when trained again) learns to flag images of this mode as fake. This in turn forces the generator to find the next vulnerability and start generating that instead. The cat-and-mouse chase between D and G continues, with G so focused on "cheating" that it loses the ability to model the other modes. This can be seen in the figure above, where the top row shows the ideal learning process G should follow, and the bottom row shows mode collapse, where G focuses on producing one mode well and ignores the others.

3.4 Sensitivity to hyperparameters

        GANs are very sensitive to hyperparameters, period. No cost function will work without good hyperparameters, so it is recommended to tune the hyperparameters extensively first rather than trying different loss functions at the outset. Tuning hyperparameters takes time and patience, and it is important to understand the basic training dynamics of your architecture before moving on to advanced loss functions, which introduce their own set of hyperparameters.
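In practice this usually means a small, systematic sweep before anything fancy. The sketch below (our own illustration; the parameter names and ranges are arbitrary starting points, not recommendations from the article) enumerates a grid over a few hyperparameters GANs are typically most sensitive to, so runs can be launched and compared under identical conditions:

```python
from itertools import product

# Hypothetical search space; values are common starting points, not prescriptions.
search_space = {
    "lr_g": [1e-4, 2e-4],          # generator learning rate
    "lr_d": [1e-4, 2e-4, 4e-4],    # discriminator learning rate
    "beta1": [0.0, 0.5],           # Adam beta1; low values are common for GANs
    "d_steps": [1, 2],             # discriminator updates per generator update
}

keys = list(search_space)
configs = [dict(zip(keys, values)) for values in product(*search_space.values())]

print(len(configs))  # 2 * 3 * 2 * 2 = 24 runs
print(configs[0])
```

Keeping the sweep exhaustive but small makes it feasible to attribute a stability change to a single hyperparameter rather than to a loss-function swap.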

4. Correlation between GAN loss and generation quality

        In usual classification tasks, the cost function is directly related to the accuracy of the model (lower loss means lower error means higher accuracy). The loss of a GAN, however, measures how well one player performs against the other in the min-max game (generator vs. discriminator), and it is common for the generator loss to increase while image quality improves. There is therefore little correlation between loss "convergence" and generation quality when training GANs, and the unstable GAN losses are often misleading. A very effective and widely accepted technique in image generation tasks is to track training progress by visually inspecting the generated images at different training stages. But this makes model comparison harder and further complicates tuning, since it is difficult to select the best model from such a subjective evaluation.

        During our experiments we quickly realized that this very critical aspect of GAN training - tracking generation progress with the right metrics - is also one of the most overlooked aspects when people talk about training GANs. Furthermore, unlike images, we cannot effectively assess training progress on spatiotemporal data "intuitively". It therefore becomes crucial to design and track metrics for spatiotemporal data that objectively indicate model performance.
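As a minimal example of such objective tracking (our own sketch, not a metric from the article), one can log the discriminator's classification accuracy on real versus generated batches alongside the raw losses: accuracy pinned near 1.0 suggests D is overpowering G, while accuracy near 0.5 suggests D can no longer tell the two apart:

```python
import numpy as np

def discriminator_accuracy(d_real, d_fake, threshold=0.5):
    """d_real, d_fake: discriminator sigmoid outputs on real/generated batches."""
    correct_real = np.sum(d_real >= threshold)   # reals scored as real
    correct_fake = np.sum(d_fake < threshold)    # fakes scored as fake
    return float(correct_real + correct_fake) / (len(d_real) + len(d_fake))

# Toy outputs: an overconfident D vs. one that is fooled half the time.
overpowering = discriminator_accuracy(np.array([0.9, 0.95, 0.99]),
                                      np.array([0.01, 0.05, 0.1]))
balanced = discriminator_accuracy(np.array([0.6, 0.4, 0.7, 0.3]),
                                  np.array([0.45, 0.55, 0.35, 0.65]))
print(overpowering)  # 1.0
print(balanced)      # 0.5
```

Unlike the raw min-max losses, this number has a stable interpretation across training, which makes it usable for comparing checkpoints.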

        Now that we have detailed some of the prominent GAN training pitfalls, the question that arises is: how do we detect and solve them? We discuss this topic in detail in the next blog in this series. After extensively trying every tip in the book, we have come up with multiple solutions for each pitfall, compiled in order of ease of implementation and respective impact, to suggest iterative enhancements to GAN training.

Shantanu Chandra

Origin: blog.csdn.net/gongdiwudu/article/details/132850693