[Caso de combate de la red neuronal Pytorch] El modelo 33WGAN-gp genera datos de simulación Fashion-MNST

1 modelo WGAN-gp para generar datos de simulación descripción del caso

Utilizando el modelo WGAN-gp para simular la generación de datos Fashion-MNIST, se utilizará el modelo WGAN-gp, el modelo Deep Convolutional GAN ​​(DCGAN) y la tecnología de normalización de instancias.

1.1 Convolución completa en DCGAN

El modelo WGAN-gp se enfoca en la parte de entrenamiento del modelo GAN, mientras que DCGAN se refiere a una GAN que usa una red neuronal convolucional, que se enfoca en la parte estructural del modelo GAN, enfocándose en la técnica de reconstrucción usando convolución completa en DCGAN.

1.1.1 Principio e implementación de DCGAN

El principio de DCGAN es similar a GAN, pero la tecnología de convolución CNN se utiliza en la red de modo GAN. Al generar datos, el generador G usa la técnica de reconstrucción por deconvolución para reconstruir la imagen original, y el discriminador D usa la técnica de convolución para identificar las características de la imagen y luego emitir un juicio. Al mismo tiempo, la red neuronal convolucional en DCGAN cambia la estructura para mejorar la calidad de la muestra y la velocidad de convergencia.

  1. En la red G, se cancelan todas las capas de agrupación, se usa la convolución completa y se usa un paso mayor o igual a 2 para el muestreo ascendente, ReLU() se usa como función de activación y tanh() se usa en la última capa .
  2. En la red D, se usa la operación de convolución que agrega reducción de muestreo en lugar de agrupación, y LeakyReLU() generalmente se usa como la función de activación, que puede retener algunas características menos de 0.
  3. Por lo general, la normalización no se usa en la última capa de D y G, el propósito de esto es garantizar que el modelo pueda aprender la distribución correcta de los datos.

El modelo DCGAN puede aprender mejor la representación jerárquica de la imagen de entrada, especialmente en la parte del generador, tendrá un mejor efecto de simulación.En el entrenamiento, se utilizará el algoritmo de optimización de Adam.

1.1.2 Implementación de convolución completa

En PyTorch, la convolución completa se implementa a través de la interfaz de convolución transpuesta ConvTranspose2d(). Los parámetros de esta interfaz tienen el mismo significado que los parámetros de la función de convolución, y también hay implementaciones de convolución transpuesta 1D y 3D.


convTranspose2d (in_channels,out_channels,kernel_size,stride=1,padding=0,output_padding=0,groups=1,bias=τrue,dilation=1,padding_mode='zeros')

Primero se transpone el kernel de convolución y luego se implementa el proceso de convolución completo.El tamaño de salida es inverso al tamaño de salida de la operación de convolución.

1.2 Muestreo ascendente y descendente 

1.2.1 Breve descripción del muestreo ascendente y descendente

  • Upsampling es ampliar la imagen.
  • Downsampling es reducir el tamaño de la imagen.

Las operaciones de muestreo ascendente y descendente no aportan más información a la imagen, pero afectarán la calidad de la imagen.

En la operación del modelo de red convolucional profunda, las dimensiones de los datos de esta capa y las capas superior e inferior pueden coincidir mediante operaciones de muestreo ascendente y descendente.

1.2.2 El papel del muestreo ascendente y descendente


Los modelos de redes neuronales a menudo usan convolución estrecha o agrupación para reducir la muestra del modelo,

Los modelos de redes neuronales utilizan convoluciones transpuestas para aumentar la muestra del modelo.

1.2.3 A través de las funciones de convolución y convolución completa, se procesa el método de submuestreo y se utiliza el método de sobremuestreo para realizar la operación de restauración.

from torch import nn
import torch

# 定义输入数据,3通道,尺寸为[12,12]
input = torch.randn(1,3,12,12)
# 输入和输出通道为3,卷积核为3,步长为2,进行下采样
downsample = nn.Conv2d(3,3,3,stride=2,padding=1)
h = downsample(input)
print(h.size())
# 输出结果:torch.SizeC[1,3,6,6]),尺寸变为[6,6]

# 输入和输出通道为3,卷积核为3,步长为2,进行上采样还原
upsample = nn.ConvTranspose2d(3,3,3,stride=2,padding=1)
output = upsample(h,output_size=input.size())
print(output.size())
# 输出结果:torch.Size([1,3,12,12]),尺寸变回[12,12]

1.3 Normalización de instancias

La normalización por lotes es para calcular la media y la desviación estándar de todos los píxeles en un lote de imágenes, y
la normalización por instancias es para normalizar una sola imagen, es decir, para calcular la media y la desviación estándar de todos los píxeles de una sola imagen.

1.3.1 Escenarios de uso de la normalización de instancias

En tareas generativas como los modelos de redes neuronales antagónicas y la transferencia de estilos, la normalización de instancias se usa a menudo en lugar de la normalización por lotes, porque la esencia de las tareas generativas es hacer coincidir la distribución de características de las muestras generadas con la distribución de características de las muestras de destino. Cada muestra de una tarea generativa tiene un estilo independiente y no debe tener demasiadas conexiones con otras muestras del lote. Por lo tanto, la normalización de instancias es adecuada para resolver tales problemas de distribución de muestras basadas en individuos.

1.3.2 Uso de la normalización de instancias

La interfaz de implementación de la normalización de instancias en PyTorch es InstanceNorm2d() en el módulo nn, así como la normalización de instancias 1D y 3D, que son similares a esta interfaz.

Instanceormd(num_features,eps=1e-5,momentum=0.1,affine=False,track_running_stats=False)

Solo concéntrese en el parámetro num_features. Este parámetro es el número de canales que deben pasarse en los datos de entrada. Otros parámetros tienen el mismo significado que los parámetros en BatchNorm2d(). Esta interfaz normalizará un solo dato de acuerdo con el canal, y la forma devuelta será la misma que la forma de entrada.

2 Ejemplo de escritura de código

2.1 Combate de código: Introducir módulos y cargar muestras----WGAN-gp-228.py (Parte 1)

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.autograd as autograd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 1.1 引入模块并载入样本:定义基本函数,加载FashionMNIST数据集
def to_img(x):
    x = 0.5 * (x+1)
    x = x.clamp(0,1)
    x = x.view(x.size(0),1,28,28)
    return x

def imshow(img,filename = None):
    npimg = img.numpy()
    plt.axis('off')
    array = np.transpose(npimg,(1,2,0))
    if filename != None:
        matplotlib.image.imsave(filename,array)
    else:
        plt.imshow(array)
        # plt.savefig(filename) # 保存图片 注释掉,因为会报错,暂时不知道什么原因 2022.3.26 15:20
        plt.show()

img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ]
)

data_dir = './fashion_mnist'

train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True)
train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True)
# 测试数据集
val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform)
test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False)
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

2.2 Código de combate: Implementando generador y discriminador----WGAN-gp-228.py (Parte 2)

# 1.2 实现生成器和判别器 :因为复杂部分都放在loss值的计算方面了,所以生成器和判别器就会简单一些。
# 生成器和判别器各自有两个卷积和两个全连接层。生成器最终输出与输入图片相同维度的数据作为模拟样本。
# 判别器的输出不需要有激活函数,并且输出维度为1的数值用来表示结果。
# 在GAN模型中,因判别器的输入则是具体的样本数据,要区分每个数据的分布特征,所以判别器使用实例归一化,
class WGAN_D(nn.Module): # 定义判别器类D :有两个卷积和两个全连接层
    def __init__(self,inputch=1):
        super(WGAN_D, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputch,64,4,2,1), # 输出形状为[batch,64,28,28]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(64,affine=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64,128,4,2,1),# 输出形状为[batch,64,14,14]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(128,affine=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(128*7*7,1024),
            nn.LeakyReLU(0.2,True)
        )
        self.fc2 = nn.Sequential(
            nn.InstanceNorm1d(1,affine=True),
            nn.Flatten(),
            nn.Linear(1024,1)
        )
    def forward(self,x,*arg): # 正向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = x.reshape(x.size(0),1,-1)
        x = self.fc2(x)
        return x.view(-1,1).squeeze(1)

# 在GAN模型中,因生成器的初始输入是随机值,所以生成器使用批量归一化。
class WGAN_G(nn.Module): # 定义生成器类G:有两个卷积和两个全连接层
    def __init__(self,input_size,input_n=1):
        super(WGAN_G, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_size * input_n,1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024,7*7*128),
            nn.ReLU(True),
            nn.BatchNorm1d(7*7*128)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 输出形状为[batch,64,14,14]
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 输出形状为[batch,64,28,28]
            nn.Tanh()
        )
    def forward(self,x,*arg): # 正向传播
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(x.size(0),128,7,7)
        x = self.upsample1(x)
        img = self.upsample2(x)
        return img

2.3 Combate de código: defina la función para completar el término de penalización de gradiente----WGAN-gp-228.py (Parte 3)

# 1.3 定义函数compute_gradient_penalty()完成梯度惩罚项
# 惩罚项的样本X_inter由一部分Pg分布和一部分Pr分布组成,同时对D(X_inter)求梯度,并计算梯度与1的平方差,最终得到gradient_penalties
lambda_gp = 10
# 计算梯度惩罚项
def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):
    # 获取一个随机数,作为真假样本的采样比例
    eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)
    # 按照eps比例生成真假样本采样值X_inter
    X_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)
    d_interpolates = D(X_inter,y_one_hot)
    fake = torch.full((real_samples.size(0),),1,device=device) # 计算梯度输出的掩码,在本例中需要对所有梯度进行计算,故需要按照样本个数生成全为1的张量。
    # 求梯度
    gradients = autograd.grad(outputs=d_interpolates, # 输出值outputs,传入计算过的张量结果
                              inputs=X_inter,# 待求梯度的输入值inputs,传入可导的张量,即requires_grad=True
                              grad_outputs=fake, # 传出梯度的掩码grad_outputs,使用1和0组成的掩码,在计算梯度之后,会将求导结果与该掩码进行相乘得到最终结果。
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True
                              )[0]
    gradients = gradients.view(gradients.size(0),-1)
    gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penaltys

2.4 Código de combate: definir la función de entrenamiento del modelo----WGAN-gp-228.py (Parte 4)

# 1.4 定义模型的训练函数
# 定义函数train(),实现模型的训练过程。
# 在函数train()中,按照对抗神经网络专题(一)中的式(8-24)实现模型的损失函数。
# 判别器的loss为D(fake_samples)-D(real_samples)再加上联合分布样本的梯度惩罚项gradient_penalties,其中fake_samples为生成的模拟数据,real_Samples为真实数据,
# 生成器的loss为-D(fake_samples)。
def train(D,G,outdir,z_dimension,num_epochs=30):
    d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定义优化器
    g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)

    os.makedirs(outdir,exist_ok=True) # 创建输出文件夹
    # 在函数train()中,判别器和生成器是分开训练的。让判别器学习的次数多一些,判别器每训练5次,生成器优化1次。
    # WGAN_gp不会因为判别器准确率太高而引起生成器梯度消失的问题,所以好的判别器会让生成器有更好的模拟效果。
    for epoch in range(num_epochs):
        for i,(img,lab) in enumerate(train_loader):
            num_img = img.size(0)

            # 训练判别器
            real_img = img.to(device)
            y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)
            for ii in range(5): # 循环训练5次
                d_optimizer.zero_grad() # 梯度清零
                # 对real_img进行判别
                real_out = D(real_img,y_one_hot)
                # 生成随机值
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot) # 生成fake_img
                fake_out = D(fake_img,y_one_hot) # 对fake_img进行判别
                # 计算梯度惩罚项
                gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)
                # 计算判别器的loss
                d_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penalty
                d_loss.backward()
                d_optimizer.step()

            # 训练生成器
            for ii in range(1): # 训练一次
                g_optimizer.zero_grad() # 梯度清0
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot)
                fake_out = D(fake_img,y_one_hot)
                g_loss =  -torch.mean(fake_out)
                g_loss.backward()
                g_optimizer.step()
        # 输出可视化结果,并将生成的结果以图片的形式存储在硬盘中
        fake_images = to_img(fake_img.cpu().data)
        real_images = to_img(real_img.cpu().data)
        rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)
        imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))
        # 输出训练结果
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))
    # 保存训练模型
    torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))

2.5 Combate de código: ahora visualiza los resultados del modelo----WGAN-gp-228.py (Parte 5)

# 1.5 定义函数,实现可视化模型结果:获取一部分测试数据,显示由模型生成的模拟数据。
def displayAndTest(D,G,z_dimension):    # 可视化结果
    sample = iter(test_loader)
    images, labels = sample.next()
    y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)
    num_img = images.size(0) # 获取样本个数
    with torch.no_grad():
        z = torch.randn(num_img, z_dimension).to(device) # 生成随机数
        fake_img = G(z, y_one_hot)
    fake_images = to_img(fake_img.cpu().data) # 生成模拟样本
    rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)
    imshow(torchvision.utils.make_grid(rel, nrow=10))
    print(labels[:10])

2.6 Código de combate: llamando a la función y entrenando el modelo----WGAN-gp-228.py (Parte 6)

# 1.6 调用函数并训练模型:实例化判别器和生成器模型,并调用函数进行训练
if __name__ == '__main__':
    z_dimension = 40  # 设置输入随机数的维度

    D = WGAN_D().to(device)  # 实例化判别器
    G = WGAN_G(z_dimension).to(device)  # 实例化生成器
    train(D, G, './w_img', z_dimension) # 训练模型
    displayAndTest(D, G, z_dimension) # 输出可视化

Se puede ver que el valor absoluto de g_loss está disminuyendo gradualmente y el valor absoluto de d_loss está aumentando gradualmente. Esto indica que las cuasi-muestras generadas son de calidad creciente. En la carpeta w_img de la ruta local, puede ver 30 imágenes, 3 de las cuales se enumeran aquí.


Muestra 3 imágenes (una por cada dos filas) durante el proceso de entrenamiento, que son los resultados de salida después de las iteraciones 1, 18 y 30 del entrenamiento, respectivamente. La primera fila de cada imagen son los datos de muestra y la segunda fila son los datos simulados generados.

Se concluye que bajo los estrictos requisitos del discriminador de WGAN-gp, los datos simulados generados por el generador son cada vez más realistas. Se puede ver a partir de los resultados generados que los datos de la muestra no corresponden a la categoría de datos simulados generados, porque no tenemos la información de la categoría generada agregada. Los efectos específicos de la categoría se pueden lograr utilizando GAN condicionales.
 

 3 Resumen de código (WGAN-gp-228.py)

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.autograd as autograd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 1.1 引入模块并载入样本:定义基本函数,加载FashionMNIST数据集
def to_img(x):
    x = 0.5 * (x+1)
    x = x.clamp(0,1)
    x = x.view(x.size(0),1,28,28)
    return x

def imshow(img,filename = None):
    npimg = img.numpy()
    plt.axis('off')
    array = np.transpose(npimg,(1,2,0))
    if filename != None:
        matplotlib.image.imsave(filename,array)
    else:
        plt.imshow(array)
        # plt.savefig(filename) # 保存图片 注释掉,因为会报错,暂时不知道什么原因 2022.3.26 15:20
        plt.show()

img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ]
)

data_dir = './fashion_mnist'

train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True)
train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True)
# 测试数据集
val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform)
test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False)
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# 1.2 实现生成器和判别器 :因为复杂部分都放在loss值的计算方面了,所以生成器和判别器就会简单一些。
# 生成器和判别器各自有两个卷积和两个全连接层。生成器最终输出与输入图片相同维度的数据作为模拟样本。
# 判别器的输出不需要有激活函数,并且输出维度为1的数值用来表示结果。
# 在GAN模型中,因判别器的输入则是具体的样本数据,要区分每个数据的分布特征,所以判别器使用实例归一化,
class WGAN_D(nn.Module): # 定义判别器类D :有两个卷积和两个全连接层
    def __init__(self,inputch=1):
        super(WGAN_D, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputch,64,4,2,1), # 输出形状为[batch,64,28,28]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(64,affine=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64,128,4,2,1),# 输出形状为[batch,64,14,14]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(128,affine=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(128*7*7,1024),
            nn.LeakyReLU(0.2,True)
        )
        self.fc2 = nn.Sequential(
            nn.InstanceNorm1d(1,affine=True),
            nn.Flatten(),
            nn.Linear(1024,1)
        )
    def forward(self,x,*arg): # 正向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = x.reshape(x.size(0),1,-1)
        x = self.fc2(x)
        return x.view(-1,1).squeeze(1)

# 在GAN模型中,因生成器的初始输入是随机值,所以生成器使用批量归一化。
class WGAN_G(nn.Module): # 定义生成器类G:有两个卷积和两个全连接层
    def __init__(self,input_size,input_n=1):
        super(WGAN_G, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_size * input_n,1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024,7*7*128),
            nn.ReLU(True),
            nn.BatchNorm1d(7*7*128)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 输出形状为[batch,64,14,14]
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 输出形状为[batch,64,28,28]
            nn.Tanh()
        )
    def forward(self,x,*arg): # 正向传播
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(x.size(0),128,7,7)
        x = self.upsample1(x)
        img = self.upsample2(x)
        return img

# 1.3 定义函数compute_gradient_penalty()完成梯度惩罚项
# 惩罚项的样本X_inter由一部分Pg分布和一部分Pr分布组成,同时对D(X_inter)求梯度,并计算梯度与1的平方差,最终得到gradient_penalties
lambda_gp = 10
# 计算梯度惩罚项
def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):
    # 获取一个随机数,作为真假样本的采样比例
    eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)
    # 按照eps比例生成真假样本采样值X_inter
    X_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)
    d_interpolates = D(X_inter,y_one_hot)
    fake = torch.full((real_samples.size(0),),1,device=device) # 计算梯度输出的掩码,在本例中需要对所有梯度进行计算,故需要按照样本个数生成全为1的张量。
    # 求梯度
    gradients = autograd.grad(outputs=d_interpolates, # 输出值outputs,传入计算过的张量结果
                              inputs=X_inter,# 待求梯度的输入值inputs,传入可导的张量,即requires_grad=True
                              grad_outputs=fake, # 传出梯度的掩码grad_outputs,使用1和0组成的掩码,在计算梯度之后,会将求导结果与该掩码进行相乘得到最终结果。
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True
                              )[0]
    gradients = gradients.view(gradients.size(0),-1)
    gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penaltys

# 1.4 定义模型的训练函数
# 定义函数train(),实现模型的训练过程。
# 在函数train()中,按照对抗神经网络专题(一)中的式(8-24)实现模型的损失函数。
# 判别器的loss为D(fake_samples)-D(real_samples)再加上联合分布样本的梯度惩罚项gradient_penalties,其中fake_samples为生成的模拟数据,real_Samples为真实数据,
# 生成器的loss为-D(fake_samples)。
def train(D,G,outdir,z_dimension,num_epochs=30):
    d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定义优化器
    g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)

    os.makedirs(outdir,exist_ok=True) # 创建输出文件夹
    # 在函数train()中,判别器和生成器是分开训练的。让判别器学习的次数多一些,判别器每训练5次,生成器优化1次。
    # WGAN_gp不会因为判别器准确率太高而引起生成器梯度消失的问题,所以好的判别器会让生成器有更好的模拟效果。
    for epoch in range(num_epochs):
        for i,(img,lab) in enumerate(train_loader):
            num_img = img.size(0)

            # 训练判别器
            real_img = img.to(device)
            y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)
            for ii in range(5): # 循环训练5次
                d_optimizer.zero_grad() # 梯度清零
                # 对real_img进行判别
                real_out = D(real_img,y_one_hot)
                # 生成随机值
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot) # 生成fake_img
                fake_out = D(fake_img,y_one_hot) # 对fake_img进行判别
                # 计算梯度惩罚项
                gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)
                # 计算判别器的loss
                d_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penalty
                d_loss.backward()
                d_optimizer.step()

            # 训练生成器
            for ii in range(1): # 训练一次
                g_optimizer.zero_grad() # 梯度清0
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot)
                fake_out = D(fake_img,y_one_hot)
                g_loss =  -torch.mean(fake_out)
                g_loss.backward()
                g_optimizer.step()
        # 输出可视化结果,并将生成的结果以图片的形式存储在硬盘中
        fake_images = to_img(fake_img.cpu().data)
        real_images = to_img(real_img.cpu().data)
        rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)
        imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))
        # 输出训练结果
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))
    # 保存训练模型
    torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))

# 1.5 定义函数,实现可视化模型结果:获取一部分测试数据,显示由模型生成的模拟数据。
def displayAndTest(D,G,z_dimension):    # 可视化结果
    sample = iter(test_loader)
    images, labels = sample.next()
    y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)
    num_img = images.size(0) # 获取样本个数
    with torch.no_grad():
        z = torch.randn(num_img, z_dimension).to(device) # 生成随机数
        fake_img = G(z, y_one_hot)
    fake_images = to_img(fake_img.cpu().data) # 生成模拟样本
    rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)
    imshow(torchvision.utils.make_grid(rel, nrow=10))
    print(labels[:10])

# 1.6 调用函数并训练模型:实例化判别器和生成器模型,并调用函数进行训练
if __name__ == '__main__':
    z_dimension = 40  # 设置输入随机数的维度

    D = WGAN_D().to(device)  # 实例化判别器
    G = WGAN_G(z_dimension).to(device)  # 实例化生成器
    train(D, G, './w_img', z_dimension) # 训练模型
    displayAndTest(D, G, z_dimension) # 输出可视化

Supongo que te gusta

Origin blog.csdn.net/qq_39237205/article/details/123756676
Recomendado
Clasificación