Generative Adversarial Network (GAN) Principle Introduction

  The generative adversarial network (GAN) is a relatively large family in deep learning. Its main function is to generate (or create) images, music, or text. During training, a generator and a discriminator are trained against each other continuously, until in the end the discriminator can hardly distinguish the data (pictures, audio, etc.) produced by the generator from the real data. Therefore, for a generative adversarial network, our ultimate goal is usually the generator, because after training ends it is the generator that produces the works created by the neural network.

One: Basic principles

  The basic idea of a generative adversarial network is the continuous confrontation between a generative network (the generator) and a discriminative network (the discriminator) during training. No matter how complex a GAN model becomes, the basic idea is still the contest between these two networks, and the basic structure must contain a generator and a discriminator.

  In order to understand the general process of GAN, Tutu uses a classic example to illustrate. Suppose the generator is an ordinary painter and the discriminator is an appraiser of famous paintings. At first, the painter's counterfeits are easily spotted by the appraiser, so the painter tries to make his paintings look more like the real famous works, yet the appraiser still recognizes them. However, through being found out again and again, the painter's skill gradually becomes more advanced, and the appraiser's eye improves along with it. In the end, the painter may paint so realistically, almost identically to the real famous paintings, that the appraiser can no longer tell the difference.

  The above example is only a rough metaphor. In more detail, the actual process is: give the appraiser some real paintings together with the painter's fakes and tell him which are real and which are fake; this step trains the appraiser. Then the painter produces several paintings for the appraiser; for any that are identified as fakes, the painter reflects on them so that later paintings are less likely to be caught; this step trains the painter. Then the painter's fakes are again given to the appraiser along with real ones, telling him which is real and which is fake... The two steps alternate until training ends.

  The above process is the general idea of GAN.

  Next is the specific structure of GAN. Tutu takes picture generation as the example, which is also the most classic application of GAN; models such as StackGAN that generate images from text are not covered here, only the simplest GAN.

  GAN can generally be regarded as unsupervised learning, because the training data consists only of real pictures with no labels. The only labels involved are "true" for the real data and "false" for the pictures produced by the generator.

1. Generator

  First, the generator is generally a network composed of multiple convolutional, fully connected, and other layers, and it uses upsampling to turn the noise it receives into a picture of suitable size. The noise is simply a random vector drawn from some distribution, usually a normal distribution; each such vector of length n ultimately generates one corresponding picture. In a sense, it is this randomness that makes every generated picture different, while whether the pictures look real depends on the generator. Of course, the distribution of the noise also affects the images: if Gaussian noise is used during training, the same kind of noise should be used at inference time.
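
  To make this concrete, here is a minimal generator sketch in PyTorch. Everything in it is an assumption for illustration (the 1x28x28 output size, the layer widths, the name SimpleGenerator); a real generator's structure depends on the pictures being generated:

import torch
from torch import nn

class SimpleGenerator(nn.Module):
    '''A minimal sketch: a fully connected projection followed by transposed-convolution upsampling.'''
    def __init__(self, n=100):
        super().__init__()
        self.fc = nn.Linear(n, 128 * 7 * 7)   # project the length-n noise vector to a small feature map
        self.up = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14 (upsampling)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),   # outputs in [-1, 1], matching pictures normalized to that range
        )
    def forward(self, z):
        x = self.fc(z).view(-1, 128, 7, 7)
        return self.up(x)   # (batch, 1, 28, 28)

z = torch.randn(16, 100)          # 16 Gaussian noise vectors of length n=100
fake = SimpleGenerator(100)(z)    # -> (16, 1, 28, 28); each noise vector yields a different picture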

2. Discriminator

  The discriminator is likewise a network composed of multiple convolutional, fully connected, and other layers, and it uses downsampling. It receives a batch of pictures of shape (batch, c, w, h), each marked true or false with a label of shape (batch, 1) (the label can also be one-hot encoded as (batch, 2)); the false pictures are those produced by the generator. The discriminator's training procedure is the same as the supervised learning we have seen before.
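
  As a matching sketch (again, the sizes and the name SimpleDiscriminator are assumptions, not part of the original framework), a minimal discriminator that downsamples with strided convolutions and outputs one true/false score per picture:

import torch
from torch import nn

class SimpleDiscriminator(nn.Module):
    '''A minimal sketch: strided convolutions downsample, a linear layer scores real vs. fake.'''
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),   # 28x28 -> 14x14 (downsampling)
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 1),
            nn.Sigmoid(),   # probability that the input picture is real
        )
    def forward(self, x):
        return self.net(x)   # (batch, 1)

imgs = torch.randn(16, 1, 28, 28)       # a batch of (batch, c, w, h) pictures
scores = SimpleDiscriminator()(imgs)    # -> (16, 1), each value in (0, 1)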

3. Generative adversarial model (GAN)

  After building the generator and discriminator, combine the two into the adversarial model GAN, which is used to train the generator network.

  The GAN model receives a batch of noise and outputs a "true" or "false" judgment. If the output is "true", the picture generated by the generator has fooled the discriminator; otherwise, the generator's internal parameters are adjusted according to the loss. This training step also looks like ordinary supervised learning, except that the discriminator's parameters must stay fixed and only the generator's parameters may change; after all, this step trains the generator, not the discriminator. The target label is "true", because we want the generator's pictures to look more real.
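
  In code, this freezing is often done by toggling requires_grad on the discriminator's parameters. A sketch of that pattern (it assumes a disc module like the one in the framework of section Two; the framework itself achieves the same effect by giving the generator-side optimizer only the generator's parameters):

for p in disc.parameters():
    p.requires_grad = False   # discriminator weights receive no gradient updates
# ... one generator training step through the combined GAN model goes here ...
for p in disc.parameters():
    p.requires_grad = True    # unfreeze before the next discriminator training step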

  The whole training process is:

(1). Sample a batch of noise vectors of length n from a Gaussian distribution.

(2). Feed the noise vectors from (1) to the generator to produce fake images.

(3). Take a batch of real images from the real data, mix them with the fake images from (2), create the corresponding labels, and train the discriminator.

(4). Sample another batch of length-n noise vectors from the Gaussian distribution, label them "true", and train the GAN; at this point the discriminator's parameters inside GAN must not be updated, so only the generator is trained.

(5). Repeat the above steps for the specified number of epochs.

Two: Basic framework

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch import nn
import argparse

class Generator(nn.Module):
    '''Generator'''
    def __init__(self):
        super(Generator, self).__init__()
        pass
    def forward(self, input):
        pass

class Discriminator(nn.Module):
    '''Discriminator'''
    def __init__(self):
        super(Discriminator, self).__init__()
        pass
    def forward(self, input):
        pass

class GAN(nn.Module):
    '''Combined GAN model: generator followed by discriminator, used to train the generator'''
    def __init__(self, gene, disc):
        super(GAN, self).__init__()
        # share the same generator/discriminator instances instead of creating new ones;
        # the discriminator is kept frozen here by the optimizer setup below
        self.gene = gene
        self.disc = disc
    def forward(self, input):
        out = self.gene(input)
        out = self.disc(out)
        return out

class dataset(Dataset):
    '''Real picture dataset'''
    def __init__(self):
        pass
    def __len__(self):
        pass
    def __getitem__(self, item):
        pass

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=10, help='the train epoch')
    parser.add_argument('--n', type=int, default=200, help='the length of noise')
    parser.add_argument('--noise_batch', type=int, default=10, help='the batch size of noise')
    parser.add_argument('--true_batch', type=int, default=10, help='the batch size of true picture')
    opt = parser.parse_args()
    gene = Generator()
    disc = Discriminator()
    gan = GAN(gene, disc)

    disc_optim = torch.optim.Adam(disc.parameters())
    # the GAN-side optimizer holds only the generator's parameters,
    # so the discriminator stays frozen while the generator trains
    gan_optim = torch.optim.Adam(gene.parameters())
    criterion = nn.MSELoss()   # MSE on the true/false score; BCELoss is also common
    for i in range(opt.epoch):
        true_pict = DataLoader(dataset(), batch_size=opt.true_batch, shuffle=True)
        for batch in true_pict:
            # ---- train the discriminator ----
            noise = torch.tensor(np.random.normal(size=(opt.noise_batch, opt.n)), dtype=torch.float32)
            false_pict = gene(noise).detach()   # detach: do not backpropagate into the generator here
            label_true = torch.ones(batch.size(0), 1)     # real pictures are labelled 1 ("true")
            label_fake = torch.zeros(opt.noise_batch, 1)  # generated pictures are labelled 0 ("false")
            loss_true = criterion(disc(batch), label_true)
            loss_fake = criterion(disc(false_pict), label_fake)
            disc_optim.zero_grad()
            (loss_true + loss_fake).backward()
            disc_optim.step()

            # ---- train the generator through the combined GAN model ----
            noise1 = torch.tensor(np.random.normal(size=(opt.noise_batch, opt.n)), dtype=torch.float32)
            label = torch.ones(opt.noise_batch, 1)   # target "true": we want the fakes to fool the discriminator
            loss_gan = criterion(gan(noise1), label)
            gan_optim.zero_grad()
            loss_gan.backward()
            gan_optim.step()

Of course, issues such as the internal training batch sizes and the relative number of training steps for the discriminator and the generator sometimes require case-by-case analysis.
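
As one example of such case-by-case tuning, the original GAN paper trains the discriminator k times for every generator update. A hedged sketch of that schedule, where train_discriminator_step and train_generator_step are hypothetical helpers wrapping the two update blocks from the loop above:

k = 5   # hypothetical hyperparameter: discriminator steps per generator step
for batch in true_pict:
    for _ in range(k):
        train_discriminator_step(batch)   # hypothetical helper: the discriminator update above
    train_generator_step()                # hypothetical helper: the generator (GAN) update above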

Three: Summary

Generative adversarial networks have not been around for long, but there are already a great many GAN models, and many new methods derive from GAN. The generative adversarial network not only opens the door for deep learning in the field of creation; more importantly, it brings a new method and a new idea that have had a profound impact on many fields.

Origin blog.csdn.net/weixin_60737527/article/details/127475864