Gan genera dígitos escritos a mano

1. GAN

La red adversa generativa (GAN) fue propuesta por primera vez por Ian Goodfellow y otros en 2014, y desde entonces se ha vuelto popular y se ha convertido en un modelo popular de aprendizaje profundo. Gan puede generar imágenes, pinturas y música muy realistas. En los últimos años, ha habido muchos casos en los que las pinturas generadas por Gan han ganado premios. Por ejemplo, la controvertida "Space Opera House" generada por IA hace algún tiempo ganó el primer lugar en la competencia de arte de la Feria Estatal de Colorado (Colorado State Fair), como se muestra en la siguiente figura.
inserte la descripción de la imagen aquí

Figura 1 La pintura "Space Opera House" generada por IA

Aunque todavía hay mucha controversia sobre el uso del poder de la IA para competir con los humanos, es innegable que la IA ha mostrado una gran perspectiva de aplicación en este campo.
El principio de GAN es muy simple, utiliza dos modelos, uno de los cuales genera continuamente datos "falsos", y el otro modelo juzga los datos "falsos" generados por el anterior. Si puede engañar al modelo discriminante, significa que los datos generados pueden confundirse con los reales.

En segundo lugar, los pasos de GAN

En GAN, A es el generador , que se encarga de generar datos "falsos", y B es el discriminador , que se encarga de juzgar la calidad de los datos generados por A, que es un proceso de juego.
Generador: Acepta un vector de ruido aleatorio x como entrada, genera un tensor G(x)
Discriminador: Acepta un tensor como entrada y emite su verdadero
o falso Tomando imágenes como ejemplo, todo el proceso de entrenamiento de GAN es el siguiente:
(1) El generador acepta ruido aleatorio y genera imágenes falsas
(2) El discriminador acepta los datos de la combinación de imágenes falsas e imágenes reales, y aprende a distinguir las imágenes verdaderas de las falsas (3) El generador genera nuevas imágenes y usa el discriminador para distinguir las verdaderas y las falsas falso, y al mismo tiempo, el resultado del discriminador se utiliza para distinguir el nivel falso
.
(4) Repita los pasos (1)~(3)

3. Generador

En principio, el generador no tiene un modelo específico, siempre y cuando pueda generar un modelo de imagen, pero considerando el entrenamiento del modelo, generalmente se selecciona la red neuronal porque se puede entrenar junto con el discriminador. El generador es responsable de generar un par de imágenes. Por supuesto, la imagen en este momento es ruido, similar a la imagen de abajo, y no es nada si miras de cerca, pero esto no es importante, porque en GAN, ¡el generador no necesita ningún dato real ! , Sí, has leído bien, no importa qué tipo de datos genere, su maestro discriminador le dirá si la imagen es verdadera o no. En otras palabras, la imagen de abajo es demasiado falsa y el maestro puede decirlo de un vistazo.
inserte la descripción de la imagen aquí

Figura 2 La imagen generada por el generador

El código del generador es el siguiente:

import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import tqdm
from IPython.display import clear_output

L = keras.layers
LATENT_DIM = 100  # 潜在空间的维度
IMAGE_SHAPE = (28, 28, 1)  # 输出图像的尺寸

# 生成器
generate_net = [
    L.Input(shape=(LATENT_DIM, )),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(1024),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(np.prod(IMAGE_SHAPE), activation='tanh'),
    L.Reshape(IMAGE_SHAPE)
]
generate = keras.models.Sequential(generate_net)
generate.summary()

Nota
1. Usamos la función de activación LeakyReLU, que es una función de activación de uso común en GAN.
2. La normalización por lotes (BN), la normalización de instancias (IN) y la normalización espectral (SN) se
usan comúnmente para la normalización. 3. La última función de activación del generador generalmente usa la función tanh

4. Discriminador

El papel del discriminador en GAN es juzgar el nivel de los datos generados. Para juzgar si es verdadero o falso, primero es necesario entrenar al discriminador. De manera similar, si desea predecir los precios de la vivienda, primero debe aprender (ajustar) los datos del precio de la vivienda. Por lo tanto, parte del entrenamiento del discriminador en GAN son datos reales. Esta parte tiene en cuenta la conveniencia de que todos descarguen y copien el código de reproducción. Usamos el conjunto de datos MNIST. El conjunto de datos MNIST se puede descargar directamente a través de tensorflow , que es más conveniente. No es suficiente tener datos reales, luego debe alimentar los datos falsos al modelo, de lo contrario, ¿cómo aprender lo verdadero y lo falso? ¿De dónde provienen los datos falsos? ¡bien! Generador, nuestro generador solo puede generar datos falsos. De esta forma de pensar, podemos construir nuestro discriminador.
conjunto de datos mnist

Figura 2 Conjunto de datos MNIST

El código para el discriminador es el siguiente:

# 判别器
discriminator_net = [
    L.Input(shape=IMAGE_SHAPE),
    L.Flatten(),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.Dense(1, activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)
discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.summary()

5. Generación de un modelo adversarial

Ya hemos aprendido que GAN se compone de dos partes: generador y discriminador, y ya hemos construido el generador y el discriminador. La suma de ellos es GAN, pero la razón por la que los construimos por separado es porque el entrenamiento de GAN es un proceso de iteración continua. Necesitamos entrenar el generador y el discriminador por separado. El discriminador debe juzgar si el generador es bueno o no. En este momento, el peso del discriminador debe congelarse y solo el peso del generador debe actualizarse, porque el objetivo del generador es mejorar continuamente la capacidad de "falsificación".
El modelo de GAN es el siguiente

adversarial_net = generate_net + discriminator_net
# 冻结判别器的权重
for layer in discriminator_net:
    layer.trainable = False

adversarial = keras.models.Sequential(adversarial_net)
adversarial.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
adversarial.summary()

Seis, entrenamiento GAN

El entrenamiento de GAN es un proceso de depuración y optimización continua, que requiere mucha experiencia. Se recomienda que algunos novatos vean algunos consejos sugeridos por otros blogueros al entrenar GAN: consejos para entrenar GAN

En primer lugar, es necesario entender que el entrenamiento de GAN es un proceso iterativo

Mencionamos anteriormente que el nivel inicial de falsificación del generador es basura. Simplemente damos un mapa de ruido aleatorio. Asignamos la etiqueta 0 a esta parte de los datos falsos, que representa datos falsos, y luego los combinamos con el conjunto de datos reales, es decir, MNIST, para formar un conjunto de datos de imagen de dos categorías para entrenar nuestro discriminador.

Después de entrenar el discriminador, comenzamos a entrenar el generador.

Esperamos generar datos muy realistas para el generador, pero la calidad del generador es buena o mala. En otras palabras, los datos generados deben ser juzgados como 1 por el discriminador, es decir, pueden ser falsificados. En términos generales, cuando se entrena el generador, en realidad es una red en serie de generador + discriminador, pero el peso del discriminador se congela, similar a solo entrenar el generador.

En este punto, usamos el generador para generar una imagen, pero debemos establecer la etiqueta de esta imagen en 1, lo que significa verdadero, ¡no se preocupe! ! ! , leíste bien, se le debe asignar una etiqueta de 1. Esta imagen es juzgada por nuestro discriminador y da como resultado 0, es decir, es falsa. De esta manera, se forma un error muy grande antes y después. Obviamente es una imagen falsa, pero el generador dice que es verdadera. En este momento, el generador ajustará desesperadamente los parámetros hasta que el discriminador juzgue que es verdadera. El término estándar para este proceso es retropropagación, ¡y los parámetros de la red del generador se actualizan enormemente! Cuando el generador posterior puede producir imágenes realistas, la retropropagación es un proceso de ajuste fino y optimización continua de los parámetros del generador.

El entrenamiento GAN se puede lograr entrenando alternativamente el discriminador y el generador

El código de entrenamiento es el siguiente:

# 数据可视化
def sample_images(batch):
    rows, columns = 3, 10
    sample_count = rows * columns
    plt.figure(figsize=(columns, rows))
    # 使用生成器生成图像
    noise = np.random.normal(0, 1, (sample_count, LATENT_DIM))
    gen_imgs = generate.predict(noise)
    # 生成器图像张量的范围从【-1,1】改为【0,1】
    gen_imgs = 0.5 * gen_imgs + 0.5

    index = 0
    for row in range(rows):
        for column in range(columns):
            image = np.reshape(gen_imgs[index], [28, 28])
            plt.subplot(rows, columns, index+1)
            plt.imshow(image, cmap='gray')
            plt.axis('off')
            index += 1
    plt.tight_layout()
    plt.show()
    return gen_imgs

# 训练
def train(batch=30000, batch_size=32):
    # 读取数据,无需标签
    (image_set, _), (_, _) = keras.datasets.mnist.load_data()
    # 数据归一化
    image_set = image_set / 127.5 - 1.
    # 数据格式转换
    image_set = image_set.reshape(len(image_set), 28, 28, 1)
    # 准备batch_size同样大小的真假标签
    valid = np.ones((batch_size))
    fake = np.zeros((batch_size))
    # 利用tqdm生成迭代器
    batch_list = tqdm.trange(batch)
    for batch in batch_list:
        #  生成器生成图像
        idx = np.random.randint(0, image_set.shape[0], batch_size)
        imgs = image_set[idx]

        # 生成噪声数据并作为生成器的输入
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        # 使用生成器生成图像
        gen_imgs = generate.predict(noise)

        # 训练判别器
        d_state_real = discriminator.train_on_batch(imgs, valid)
        d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_state = 0.5 * np.add(d_state_real, d_state_fake)

        # 训练生成器
        adv_state = adversarial.train_on_batch(noise, valid)

        # 更新进度条后缀文本,用于输出训练进度
        state = f"[D loss:{
      
      d_state[0]:.4f} acc: {
      
      d_state[1]:.4f}]" \
                f"[A loss:{
      
      adv_state[0]:.4f} acc: {
      
      adv_state[1]:.4f}"
        batch_list.set_postfix(state=state)
        if (batch + 1) % 50 == 0:
            clear_output(wait=True)
            _ = sample_images(batch)

train()

Durante el proceso de entrenamiento, configuramos cada 50 lotes para generar las imágenes generadas por el generador. La
imagen inicial . La imagen generada después de 2000
inserte la descripción de la imagen aquí
iteraciones .
inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí

7. Precauciones para GAN

GAN的训练过程是一个动态过程,每个批次是新的开始,不会有简单的梯度下降过程,而是一个不断对抗平衡的过程,类似于minimax,我们要的是最小的判别器损失,最大的生成器误差,因此GAN的训练需要一些技巧,比如
1、一开始无需分类精度很高的判别器
2、初始学习率要小,否则下降过快或者Max过大不利于GAN的拟合
3、生成器和迭代器无需训练相同次数,比如可以生成器训练1次,判别器训练5次
4、迭代次数需要不断微调,迭代次数过小有可能生成的图像效果一般,迭代次数过大也会导致生成的图像效果一般,很多人会疑问迭代次数过大为什么会导致生成的图像效果一般,因为判别器每次训练更新权重用到的是生成器生成的假数据和真实数据,后期生成器生成的数据已经非常逼真了,而判别器学习到仍然判定为假,因此反而会导致生成器又开始生成很假的数据,如该博主利用GAN生成动漫头像文章点击这,迭代200次的图片如下
inserte la descripción de la imagen aquí
当其迭代750次后出现了上面提到的问题,由于生成器生成了图像质量非常高,但判别器仍然判定为假,导致生成器开始产生反向作用,如下所示
inserte la descripción de la imagen aquí

8. Evaluación de GAN

La evaluación de GAN siempre ha sido un problema difícil. En los primeros días, las personas juzgaban la calidad de las imágenes generadas a simple vista, pero es innegable que este método de evaluación es obviamente defectuoso. Desde 2016, han surgido los métodos de evaluación de GAN. Actualmente, los más populares son:

1, puntuación de inicio

Inception Score (IS) mide la claridad y diversidad de las imágenes generadas por el modelo mediante el uso del modelo de clasificación de imágenes de Google inception Net. Cuanto mayor sea el Inception Score, mejor será el modelo.

2. Frechet Distancia de inicio

La distancia Frechet Inception (FID) evalúa la diferencia entre la muestra generada y la muestra real al comparar la diferencia entre las características abstractas de la muestra real y la muestra generada en el modelo Inception V3. Cuanto más pequeño es el FID, más pequeño es el modelo.

nueve, otro

Existen muchos tipos de redes generativas de confrontación, como GAN, ACGAN, DCGAN, Pix2Pix, etc.

Para ver mis otros blogs, haz clic aquí

Supongo que te gusta

Origin blog.csdn.net/JaysonWong/article/details/127092221
Recomendado
Clasificación