Aprendizaje profundo - generación de red de confrontación GAN

concepto basico

descripción general

GAN es un modelo de aprendizaje profundo, que es un algoritmo de aprendizaje no supervisado que se utiliza para generar datos realistas como imágenes, audio, texto, etc. a partir de ruido aleatorio. La estructura de GAN consta de dos redes neuronales: Generador y Discriminador, que compiten entre sí para impulsar todo el aprendizaje del modelo.

Dos componentes principales:

1. Generador:

El objetivo de un generador es transformar el ruido aleatorio (generalmente un vector muestreado de una distribución normal o uniforme) en muestras de datos realistas. Este proceso se puede entender como el generador aprende la distribución de los datos y trata de crear nuevas muestras similares a los datos reales. Inicialmente, la salida del generador puede ser aleatoria, pero a medida que avanza el entrenamiento, generará gradualmente datos más realistas para engañar al discriminador.

2. Discriminador (Discriminador):

La tarea del discriminador es clasificar la muestra de datos de entrada, es decir, juzgar si se trata de datos reales o falsos generados por el generador. El discriminador es un clasificador binario cuyo objetivo es distinguir los datos reales de los datos falsos generados por el generador con la mayor precisión posible.

proceso de entrenamiento

1. Al comienzo del entrenamiento, el generador genera aleatoriamente algunas muestras de datos falsos y se las proporciona al discriminador junto con los datos reales.
2. El discriminador clasifica los datos de entrada y genera una estimación de probabilidad (0 para datos falsos y 1 para datos reales).
3. Según la salida del discriminador, calcule la probabilidad de que los datos generados por el generador se juzguen como datos reales y use esta probabilidad como la "pérdida" del generador.
4. Luego, de acuerdo con la pérdida del generador, los parámetros del generador se actualizan para que el generador pueda generar muestras de datos más realistas.
5. Luego, vuelva a generar aleatoriamente un lote de muestras de datos falsos y entrégueselos al discriminador junto con los datos reales, y repita el proceso anterior.

A través de este proceso de competencia y juego, el generador y el discriminador optimizan gradualmente sus capacidades hasta que el generador puede generar muestras de datos muy realistas, mientras que el discriminador no puede distinguir con precisión entre lo real y lo falso.

código y comentarios

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Hyper Parameters
BATCH_SIZE = 64
# 生成器学习率
LR_G = 0.0001  # learning rate for generator
# 判别器学习率
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 5  # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15  # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])


# 定义函数artist_works,用于生成来自著名艺术家的真实画作数据
def artist_works():
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings


# 定义生成器(Generator)和判别器(Discriminator)
# 初级画家
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # 生成器输入为随机噪声数据
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # 生成器输出为生成的艺术作品
)

# 初级鉴赏家
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # 判别器输入为艺术作品数据
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # 判别器输出为对艺术作品的真假概率
)

# 定义两个优化器,分别用于优化生成器和判别器的参数
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

# 开始GAN的训练
plt.ion()  # 打开交互式绘图

for step in range(10000):
    # 获取来自艺术家的真实画作数据
    artist_paintings = artist_works()
    # 生成随机的噪声数据
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)
    # 生成器生成假的艺术画作
    G_paintings = G(G_ideas)

    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings)
    # 计算生成器的损失
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_G.zero_grad()  # 清空生成器的梯度
    G_loss.backward()  # 反向传播计算生成器的梯度
    opt_G.step()  # 优化生成器的参数

    # 判别器对真实画作进行判断,试图增大判别器对真实画作的概率
    prob_artist0 = D(artist_paintings)
    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings.detach())
    # 计算判别器的损失
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()  # 清空判别器的梯度
    D_loss.backward(retain_graph=True)  # 反向传播计算判别器的梯度(保留计算图以供下一次计算)
    opt_D.step()  # 优化判别器的参数

    if step % 50 == 0:  # 每隔一段时间进行绘图显示
        # 绘制生成的画作、上界和下界
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={
    
    'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={
    
    'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()  # 关闭交互式绘图
plt.show()  # 展示绘制的图像

resultado de la operación

inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/Elon15/article/details/131832395
Recomendado
Clasificación