Deep learning - Generative Adversarial Network (GAN)

basic concept

overview

A generative adversarial network (GAN) is a deep learning model trained without labels (unsupervised learning) that generates realistic data such as images, audio, or text from random noise. A GAN consists of two neural networks, a generator and a discriminator, which compete with each other; this competition drives the learning of the whole model.

Two main components:

1. Generator:

The goal of the generator is to transform random noise (usually a vector sampled from a normal or uniform distribution) into realistic data samples. In effect, the generator learns the distribution of the real data and tries to create new samples that resemble it. At first its output is essentially random, but as training progresses it produces increasingly realistic data in order to fool the discriminator.

2. Discriminator:

The discriminator is a binary classifier: given an input sample, it judges whether the sample is real data or fake data produced by the generator. Its goal is to tell the two apart as accurately as possible.
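To make the two roles concrete, here is a minimal sketch of how data flows through the two networks. The layer sizes and the names `NOISE_DIM`, `SAMPLE_DIM`, `BATCH`, `generator`, and `discriminator` are assumptions chosen for illustration only, and the networks are untrained.

```python
import torch
import torch.nn as nn

NOISE_DIM = 5    # hypothetical size of the noise vector
SAMPLE_DIM = 15  # hypothetical size of one data sample
BATCH = 4        # hypothetical batch size

# Generator: noise vector -> fake data sample
generator = nn.Sequential(
    nn.Linear(NOISE_DIM, 32),
    nn.ReLU(),
    nn.Linear(32, SAMPLE_DIM),
)

# Discriminator: data sample -> probability that the sample is real
discriminator = nn.Sequential(
    nn.Linear(SAMPLE_DIM, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
    nn.Sigmoid(),
)

noise = torch.randn(BATCH, NOISE_DIM)  # sampled from a standard normal distribution
fake_samples = generator(noise)        # shape: (BATCH, SAMPLE_DIM)
p_real = discriminator(fake_samples)   # shape: (BATCH, 1), one probability per sample

print(fake_samples.shape, p_real.shape)
```

The generator never sees real data directly; it only receives feedback through the discriminator's output on its samples.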

training process

1. At each training step, the generator turns a batch of random noise into fake data samples, which are given to the discriminator along with a batch of real data.
2. The discriminator outputs, for each sample, an estimated probability that it is real (close to 1 for real data, close to 0 for fake data).
3. The generator's loss is computed from the discriminator's output on the fake samples: the higher the probability that the fake samples are judged to be real, the lower the loss.
4. The generator's parameters are updated to reduce this loss, so that it generates more realistic data samples.
5. The discriminator is updated in turn so that it assigns a high probability to real data and a low probability to fake data. Then a new batch of noise is drawn and the process repeats.

Through this adversarial game, the generator and the discriminator improve each other until the generator produces highly realistic samples and the discriminator can no longer reliably tell real from fake.
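This competition is commonly written as a minimax objective. The formula below is the standard formulation from the original GAN paper, given as a reference; the symbols $x$ (a real sample), $z$ (a noise vector), $p_{\text{data}}$ and $p_z$ (their distributions) are notation introduced here rather than taken from the text above.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
```

The discriminator $D$ tries to maximize $V$ (it wants $D(x)$ near 1 and $D(G(z))$ near 0), while the generator $G$ tries to minimize the second term (it wants $D(G(z))$ near 1).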

code and comments

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for the generator
LR_D = 0.0001  # learning rate for the discriminator
N_IDEAS = 5  # size of the noise vector fed to the generator (its "ideas" for a painting)
ART_COMPONENTS = 15  # number of points the generator draws on the canvas
# x-coordinates of the points, repeated for each sample: shape (BATCH_SIZE, ART_COMPONENTS)
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])


# Generate a batch of "real" paintings by the famous artist:
# quadratic curves a * x^2 + (a - 1), with a drawn uniformly from [1, 2)
def artist_works():
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings


# Define the generator (G) and the discriminator (D)
# Novice painter
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # input: random noise
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # output: a generated painting
)

# Novice art critic
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # input: a painting (real or generated)
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # output: probability that the painting is real
)

# Two optimizers, one for the discriminator's parameters and one for the generator's
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

# Start training the GAN
plt.ion()  # turn on interactive plotting

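for step in range(10000):
    # Get a batch of real paintings from the artist
    artist_paintings = artist_works()
    # Sample random noise as the generator's input
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)
    # The generator produces fake paintings
    G_paintings = G(G_ideas)

    # D's probability that the generated paintings are real; G tries to increase it
    prob_artist1 = D(G_paintings)
    # Generator loss: minimizing log(1 - D(G(z))) pushes D(G(z)) toward 1
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_G.zero_grad()  # clear the generator's gradients
    G_loss.backward()  # backpropagate to compute the generator's gradients
    opt_G.step()  # update the generator's parameters

    # D's probability that the real paintings are real; D tries to increase it
    prob_artist0 = D(artist_paintings)
    # D's probability that the generated paintings are real; D tries to decrease it
    # (detach() keeps this loss from backpropagating into the generator)
    prob_artist1 = D(G_paintings.detach())
    # Discriminator loss
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()  # clear the discriminator's gradients, including those left by G_loss.backward()
    D_loss.backward()  # backpropagate; retain_graph is not needed since this graph is not reused
    opt_D.step()  # update the discriminator's parameters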

    if step % 50 == 0:  # redraw the plot every 50 steps
        # Plot the generated painting with the upper and lower bounds of the real paintings
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.detach().numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.item(), fontdict={'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()  # turn off interactive plotting
plt.show()  # keep the final figure on screen
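The plot label in the code above mentions "-1.38 for G to converge". A short sketch of where that number comes from: when the discriminator can no longer tell real from fake, it outputs 0.5 for every sample, so the displayed score log D(x) + log(1 - D(G(z))) becomes 2 * log(0.5), roughly -1.386 (the label shows it truncated). The helper `d_score` is a hypothetical name introduced here for illustration and works on single probabilities rather than batches.

```python
import math


def d_score(p_real: float, p_fake: float) -> float:
    """Return log D(x) + log(1 - D(G(z))) for one real/fake pair of probabilities."""
    return math.log(p_real) + math.log(1.0 - p_fake)


# The discriminator cannot tell real from fake: it outputs 0.5 for both
equilibrium = d_score(0.5, 0.5)

# A confident and correct discriminator scores closer to 0
confident = d_score(0.9, 0.1)

print(round(equilibrium, 2))  # -1.39
print(round(confident, 2))    # -0.21
```

A score drifting from near 0 toward about -1.39 during training therefore suggests the generator is getting harder for the discriminator to beat.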

running result

[Figure: plot produced by the script above, showing the generated painting together with the upper and lower bound curves.]


Origin blog.csdn.net/Elon15/article/details/131832395