PyTorch implements a simple generative adversarial network (GAN)

Recently I read some material on GANs and summarized the points that had confused me.

Generative model

        The deep learning models we usually encounter are discriminative models: a model is trained on labeled samples and then used to classify or predict new samples. Discriminative models embody the learning ability of deep learning. But the power of artificial intelligence should not stop at learning from the known; to be truly interesting it must also be able to create. Generative models embody that creative ability: in contrast to the workflow of a discriminative model, a generative model learns the rules of the data and generates new samples from them.
        Generative modeling is dominated by two ideas: the variational autoencoder (VAE) and the generative adversarial network (GAN). The VAE is based on Bayesian inference: it models the data with a latent distribution and samples new data from that model. The GAN instead takes a game-theoretic view, searching for a discriminator network (D) and a generator network (G) that reach a Nash equilibrium.

GAN architecture

        An intuitive way to understand a GAN: imagine a forger who wants to fake a painting by Leonardo da Vinci. At first his technique is poor, but he mixes his fakes in with genuine works by Leonardo and shows them to a connoisseur, who assesses each piece and tells him which ones look fake. Based on this feedback the forger improves his fakes. Over time the forger's skill grows stronger and stronger, the connoisseur's eye grows sharper and sharper, and the fakes look more and more like real paintings.

        That is the principle of a GAN: a counterfeiter G and a connoisseur D, each trained with the goal of defeating the other.

        The figure below is a GAN architecture diagram taken from the book; it is simple and clear.

GAN's loss function

  Let x denote an image and let D(x) be the discriminator network, a binary classifier whose output is the probability that x comes from the training data rather than being a fake produced by the generator. For the generator, let z be a vector sampled from a standard normal distribution; G(z) is the generator function that maps z into data space. The goal of G is to estimate the distribution of the training data (p_data) so that it can generate convincing fake samples. D(G(z)) is then D's estimate of the probability that the generator's output is a real image. The discriminator D and the generator G thus play a minimax game: D tries to maximize the probability of correctly classifying real and fake data (log D(x)), while G tries to minimize the probability that D labels its output as fake (log(1 - D(G(z)))).
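
  Written out, this is the standard minimax value function from the original GAN paper (Goodfellow et al., 2014):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]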

 

PyTorch implementation of a GAN

      The toy example below (from the Morvan Python tutorial referenced at the end) trains a generator to draw one-dimensional "paintings": quadratic curves a*x^2 + (a - 1) with a drawn uniformly from [1, 2], so every real painting lies between the lower bound x^2 and the upper bound 2*x^2 + 1.

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt



# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001           # learning rate for generator
LR_D = 0.0001           # learning rate for discriminator
N_IDEAS = 5             # dimensionality of the generator's random input (its "ideas" for an art work)
ART_COMPONENTS = 15     # number of points G can draw on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
print(PAINT_POINTS)
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.legend(loc='upper right')
plt.show()


def artist_works():     # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]     # one coefficient a in [1, 2) per painting
    paintings = a * np.power(PAINT_POINTS, 2) + (a-1)               # each real painting is the curve a*x^2 + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (drawn from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()   # interactive mode: lets the plot update during training

for step in range(10000):
    artist_paintings = artist_works()           # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)          # D's probability that the real paintings are real
    prob_artist1 = D(G_paintings.detach())      # D's probability that the fakes are real (detached so D's update does not touch G)

    # D maximizes log D(x) + log(1 - D(G(z))), i.e. minimizes the negative mean
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # re-score the fakes with the updated D so G's backward pass uses a fresh graph (no retain_graph needed)
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1))   # G minimizes log(1 - D(G(z)))

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()
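
      Two notes on the loop above. First, the numbers printed on the plot follow directly from the loss: at equilibrium D outputs 0.5 for every input, so D_loss = -(log 0.5 + log 0.5) = 2 log 2 ≈ 1.39, which gives the plotted "D score" of -D_loss ≈ -1.38 and a D accuracy of 0.5. Second, the generator loss log(1 - D(G(z))) saturates when D confidently rejects the fakes early in training; the original GAN paper suggests having G maximize log D(G(z)) instead (the "non-saturating" loss). In this script that is a one-line change:

G_loss = - torch.mean(torch.log(prob_artist1))   # non-saturating variant: G maximizes log D(G(z))

      Once training is done, generating new "paintings" is just a forward pass of G on fresh noise; no discriminator is needed. A minimal sketch reusing G, N_IDEAS, and PAINT_POINTS from above:

with torch.no_grad():                   # no gradients needed at inference time
    ideas = torch.randn(3, N_IDEAS)     # three fresh random idea vectors
    new_paintings = G(ideas)            # shape: (3, ART_COMPONENTS)
for p in new_paintings:
    plt.plot(PAINT_POINTS[0], p.numpy())
plt.show()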

 

Reference materials: "Python Deep Learning: Based on PyTorch", "Principles and Practice of Deep Learning and Image Recognition", and the Morvan Python (莫烦Python) tutorials.


Origin: blog.csdn.net/rytyy/article/details/105128976