【Pytorchニューラルネットワーク戦闘事例】33WGAN-gpモデルがFashion-MNSTシミュレーションデータを生成

1シミュレーションデータを生成するためのWGAN-gpモデルケースの説明

WGAN-gpモデルを使用してFashion-MNISTデータの生成をシミュレートし、WGAN-gpモデル、Deep Convolutional GAN(DCGAN)モデル、およびインスタンス正規化テクノロジーを使用します。

1.1DCGANでの完全な畳み込み

WGAN-gpモデルはGANモデルのトレーニング部分に焦点を当てていますが、DCGANは畳み込みニューラルネットワークを使用したGANを指し、GANモデルの構造部分に焦点を当て、DCGANでの完全な畳み込みを使用した再構成手法に焦点を当てています。

1.1.1DCGANの原則と実装

DCGANの原理はGANに似ていますが、CNN畳み込み技術がGANモードネットワークで使用されます。データを生成する場合、ジェネレータGはデコンボリューション再構成手法を使用して元の画像を再構成し、弁別器Dは畳み込み手法を使用して画像の特徴を識別して判断します。同時に、DCGANの畳み込みニューラルネットワークは構造を変更して、サンプルの品質と収束速度を向上させます。

  1. Gネットワ​​ークでは、すべてのプーリングレイヤーがキャンセルされ、完全な畳み込みが使用され、2以上のストライドがアップサンプリングに使用され、ReLU()が活性化関数として使用され、tanh()が最後のレイヤーで使用されます。 。
  2. Dネットワークでは、プールの代わりにダウンサンプリングを追加する畳み込み演算が使用され、LeakyReLU()は通常、0未満の一部の機能を保持できる活性化関数として使用されます。
  3. 通常、正規化はDとGの最後の層では使用されません。これの目的は、モデルがデータの正しい分布を学習できるようにすることです。

DCGANモデルは、特にジェネレーター部分で、入力画像の階層表現をより適切に学習でき、シミュレーション効果が向上します。トレーニングでは、Adam最適化アルゴリズムが使用されます。

1.1.2完全な畳み込みの実装

PyTorchでは、完全な畳み込みは、転置された畳み込みインターフェイスConvTranspose2d()を介して実装されます。このインターフェースのパラメーターは、畳み込み関数のパラメーターと同じ意味を持ち、1Dおよび3D転置畳み込みの実装もあります。


convTranspose2d (in_channels,out_channels,kernel_size,stride=1,padding=0,output_padding=0,groups=1,bias=τrue,dilation=1,padding_mode='zeros')

畳み込みカーネルが最初に転置され、次に完全な畳み込みプロセスが実装されます。出力サイズは、畳み込み演算によって出力されるサイズと逆になります。

1.2アップサンプリングとダウンサンプリング 

1.2.1アップサンプリングとダウンサンプリングの簡単な説明

  • アップサンプリングは画像を拡大することです。
  • ダウンサンプリングは、画像のサイズを縮小します。

アップサンプリングおよびダウンサンプリング操作は、画像により多くの情報をもたらすことはありませんが、画質に影響を与えます。

深層畳み込みネットワークモデルの操作では、この層と上層および下層のデータの次元を、アップサンプリングおよびダウンサンプリング操作によって一致させることができます。

1.2.2アップサンプリングとダウンサンプリングの役割


ニューラルネットワークモデルは、モデルをダウンサンプリングするために狭い畳み込みまたはプーリングを使用することがよくあります。

ニューラルネットワークモデルは、転置された畳み込みを使用してモデルをアップサンプリングします。

1.2.3畳み込み関数と完全畳み込み関数により、ダウンサンプリング方式が処理され、アップサンプリング方式を使用して復元操作が実行されます。

from torch import nn
import torch

# 定义输入数据,3通道,尺寸为[12,12]
input = torch.randn(1,3,12,12)
# 输入和输出通道为3,卷积核为3,步长为2,进行下采样
downsample = nn.Conv2d(3,3,3,stride=2,padding=1)
h = downsample(input)
print(h.size())
# 输出结果:torch.SizeC[1,3,6,6]),尺寸变为[6,6]

# 输入和输出通道为3,卷积核为3,步长为2,进行上采样还原
upsample = nn.ConvTranspose2d(3,3,3,stride=2,padding=1)
output = upsample(h,output_size=input.size())
print(output.size())
# 输出结果:torch.Size([1,3,12,12]),尺寸变回[12,12]

1.3インスタンスの正規化

バッチ正規化は、画像のバッチ内のすべてのピクセルの平均と標準偏差を計算することであり、
インスタンスの正規化は、単一の画像を正規化することです。つまり、単一の画像のすべてのピクセルの平均と標準偏差を計算することです。

1.3.1インスタンス正規化の使用シナリオ

敵対的なニューラルネットワークモデルやスタイル転送などの生成タスクでは、生成タスクの本質が生成されたサンプルの特徴分布をターゲットサンプルの特徴分布と一致させることであるため、バッチ正規化の代わりにインスタンス正規化がよく使用されます。生成タスクの各サンプルには独立したスタイルがあり、バッチ内の他のサンプルとの接続が多すぎないようにする必要があります。したがって、インスタンスの正規化は、このような個人ベースのサンプル分布の問題を解決するのに適しています。

1.3.2インスタンスの正規化の使用

PyTorchでのインスタンス正規化の実装インターフェイスは、nnモジュールの下のInstanceNorm2d()と、このインターフェイスに類似した1Dおよび3Dインスタンス正規化です。

Instanceormd(num_features,eps=1e-5,momentum=0.1,affine=False,track_running_stats=False)

パラメータnum_featuresのみに注目してください。このパラメータは、入力データで渡す必要のあるチャネルの数です。他のパラメータは、BatchNorm2d()のパラメータと同じ意味です。このインターフェイスは、チャネルに従って単一のデータを正規化します。返される形状は入力形状と同じになります。

2サンプルコードの記述

2.1コードコンバット:モジュールの導入とサンプルのロード----WGAN-gp-228.py(パート1)

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.autograd as autograd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 1.1 引入模块并载入样本:定义基本函数,加载FashionMNIST数据集
def to_img(x):
    x = 0.5 * (x+1)
    x = x.clamp(0,1)
    x = x.view(x.size(0),1,28,28)
    return x

def imshow(img,filename = None):
    npimg = img.numpy()
    plt.axis('off')
    array = np.transpose(npimg,(1,2,0))
    if filename != None:
        matplotlib.image.imsave(filename,array)
    else:
        plt.imshow(array)
        # plt.savefig(filename) # 保存图片 注释掉,因为会报错,暂时不知道什么原因 2022.3.26 15:20
        plt.show()

img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ]
)

data_dir = './fashion_mnist'

train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True)
train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True)
# 测试数据集
val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform)
test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False)
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

2.2コードコンバット:ジェネレーターとディスクリミネーターの実装----WGAN-gp-228.py(パート2)

# 1.2 实现生成器和判别器 :因为复杂部分都放在loss值的计算方面了,所以生成器和判别器就会简单一些。
# 生成器和判别器各自有两个卷积和两个全连接层。生成器最终输出与输入图片相同维度的数据作为模拟样本。
# 判别器的输出不需要有激活函数,并且输出维度为1的数值用来表示结果。
# 在GAN模型中,因判别器的输入则是具体的样本数据,要区分每个数据的分布特征,所以判别器使用实例归一化,
class WGAN_D(nn.Module): # 定义判别器类D :有两个卷积和两个全连接层
    def __init__(self,inputch=1):
        super(WGAN_D, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputch,64,4,2,1), # 输出形状为[batch,64,28,28]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(64,affine=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64,128,4,2,1),# 输出形状为[batch,64,14,14]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(128,affine=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(128*7*7,1024),
            nn.LeakyReLU(0.2,True)
        )
        self.fc2 = nn.Sequential(
            nn.InstanceNorm1d(1,affine=True),
            nn.Flatten(),
            nn.Linear(1024,1)
        )
    def forward(self,x,*arg): # 正向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = x.reshape(x.size(0),1,-1)
        x = self.fc2(x)
        return x.view(-1,1).squeeze(1)

# 在GAN模型中,因生成器的初始输入是随机值,所以生成器使用批量归一化。
class WGAN_G(nn.Module): # 定义生成器类G:有两个卷积和两个全连接层
    def __init__(self,input_size,input_n=1):
        super(WGAN_G, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_size * input_n,1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024,7*7*128),
            nn.ReLU(True),
            nn.BatchNorm1d(7*7*128)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 输出形状为[batch,64,14,14]
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 输出形状为[batch,64,28,28]
            nn.Tanh()
        )
    def forward(self,x,*arg): # 正向传播
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(x.size(0),128,7,7)
        x = self.upsample1(x)
        img = self.upsample2(x)
        return img

2.3コードコンバット:勾配ペナルティ項を完了する関数を定義する----WGAN-gp-228.py(パート3)

# 1.3 定义函数compute_gradient_penalty()完成梯度惩罚项
# 惩罚项的样本X_inter由一部分Pg分布和一部分Pr分布组成,同时对D(X_inter)求梯度,并计算梯度与1的平方差,最终得到gradient_penalties
lambda_gp = 10
# 计算梯度惩罚项
def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):
    # 获取一个随机数,作为真假样本的采样比例
    eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)
    # 按照eps比例生成真假样本采样值X_inter
    X_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)
    d_interpolates = D(X_inter,y_one_hot)
    fake = torch.full((real_samples.size(0),),1,device=device) # 计算梯度输出的掩码,在本例中需要对所有梯度进行计算,故需要按照样本个数生成全为1的张量。
    # 求梯度
    gradients = autograd.grad(outputs=d_interpolates, # 输出值outputs,传入计算过的张量结果
                              inputs=X_inter,# 待求梯度的输入值inputs,传入可导的张量,即requires_grad=True
                              grad_outputs=fake, # 传出梯度的掩码grad_outputs,使用1和0组成的掩码,在计算梯度之后,会将求导结果与该掩码进行相乘得到最终结果。
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True
                              )[0]
    gradients = gradients.view(gradients.size(0),-1)
    gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penaltys

2.4コードコンバット:モデルのトレーニング機能を定義する----WGAN-gp-228.py(パート4)

# 1.4 定义模型的训练函数
# 定义函数train(),实现模型的训练过程。
# 在函数train()中,按照对抗神经网络专题(一)中的式(8-24)实现模型的损失函数。
# 判别器的loss为D(fake_samples)-D(real_samples)再加上联合分布样本的梯度惩罚项gradient_penalties,其中fake_samples为生成的模拟数据,real_Samples为真实数据,
# 生成器的loss为-D(fake_samples)。
def train(D,G,outdir,z_dimension,num_epochs=30):
    d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定义优化器
    g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)

    os.makedirs(outdir,exist_ok=True) # 创建输出文件夹
    # 在函数train()中,判别器和生成器是分开训练的。让判别器学习的次数多一些,判别器每训练5次,生成器优化1次。
    # WGAN_gp不会因为判别器准确率太高而引起生成器梯度消失的问题,所以好的判别器会让生成器有更好的模拟效果。
    for epoch in range(num_epochs):
        for i,(img,lab) in enumerate(train_loader):
            num_img = img.size(0)

            # 训练判别器
            real_img = img.to(device)
            y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)
            for ii in range(5): # 循环训练5次
                d_optimizer.zero_grad() # 梯度清零
                # 对real_img进行判别
                real_out = D(real_img,y_one_hot)
                # 生成随机值
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot) # 生成fake_img
                fake_out = D(fake_img,y_one_hot) # 对fake_img进行判别
                # 计算梯度惩罚项
                gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)
                # 计算判别器的loss
                d_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penalty
                d_loss.backward()
                d_optimizer.step()

            # 训练生成器
            for ii in range(1): # 训练一次
                g_optimizer.zero_grad() # 梯度清0
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot)
                fake_out = D(fake_img,y_one_hot)
                g_loss =  -torch.mean(fake_out)
                g_loss.backward()
                g_optimizer.step()
        # 输出可视化结果,并将生成的结果以图片的形式存储在硬盘中
        fake_images = to_img(fake_img.cpu().data)
        real_images = to_img(real_img.cpu().data)
        rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)
        imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))
        # 输出训练结果
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))
    # 保存训练模型
    torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))

2.5コードコンバット:モデルの結果を視覚化する----WGAN-gp-228.py(パート5)

# 1.5 定义函数,实现可视化模型结果:获取一部分测试数据,显示由模型生成的模拟数据。
def displayAndTest(D,G,z_dimension):    # 可视化结果
    sample = iter(test_loader)
    images, labels = sample.next()
    y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)
    num_img = images.size(0) # 获取样本个数
    with torch.no_grad():
        z = torch.randn(num_img, z_dimension).to(device) # 生成随机数
        fake_img = G(z, y_one_hot)
    fake_images = to_img(fake_img.cpu().data) # 生成模拟样本
    rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)
    imshow(torchvision.utils.make_grid(rel, nrow=10))
    print(labels[:10])

2.6コードコンバット:関数の呼び出しとモデルのトレーニング----WGAN-gp-228.py(パート6)

# 1.6 调用函数并训练模型:实例化判别器和生成器模型,并调用函数进行训练
if __name__ == '__main__':
    z_dimension = 40  # 设置输入随机数的维度

    D = WGAN_D().to(device)  # 实例化判别器
    G = WGAN_G(z_dimension).to(device)  # 实例化生成器
    train(D, G, './w_img', z_dimension) # 训练模型
    displayAndTest(D, G, z_dimension) # 输出可视化

g_lossの絶対値が徐々に減少し、d_lossの絶対値が徐々に増加していることがわかります。これは、生成された準サンプルの品質が向上していることを示しています。ローカルパスのw_imgフォルダーには、30枚の写真があり、そのうち3枚がここにリストされています。


トレーニングプロセス中の3つの画像(2行ごとに1つ)を表示します。これらの画像は、それぞれ1回目、18回目、および30回目のトレーニングの反復後の出力結果です。各画像の最初の行はサンプルデータであり、2番目の行は生成されたシミュレーションデータです。

WGAN-gpの弁別器の厳しい要件の下で、ジェネレータによって生成されたシミュレートされたデータはますます現実的になっていると結論付けられます。生成されたカテゴリの情報が追加されていないため、生成された結果から、サンプルデータが生成されたシミュレーションデータカテゴリに対応していないことがわかります。カテゴリ固有の効果は、条件付きGANを使用して実現できます。
 

 3コードの概要(WGAN-gp-228.py)

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.autograd as autograd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 1.1 引入模块并载入样本:定义基本函数,加载FashionMNIST数据集
def to_img(x):
    x = 0.5 * (x+1)
    x = x.clamp(0,1)
    x = x.view(x.size(0),1,28,28)
    return x

def imshow(img,filename = None):
    npimg = img.numpy()
    plt.axis('off')
    array = np.transpose(npimg,(1,2,0))
    if filename != None:
        matplotlib.image.imsave(filename,array)
    else:
        plt.imshow(array)
        # plt.savefig(filename) # 保存图片 注释掉,因为会报错,暂时不知道什么原因 2022.3.26 15:20
        plt.show()

img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ]
)

data_dir = './fashion_mnist'

train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True)
train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True)
# 测试数据集
val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform)
test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False)
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# 1.2 实现生成器和判别器 :因为复杂部分都放在loss值的计算方面了,所以生成器和判别器就会简单一些。
# 生成器和判别器各自有两个卷积和两个全连接层。生成器最终输出与输入图片相同维度的数据作为模拟样本。
# 判别器的输出不需要有激活函数,并且输出维度为1的数值用来表示结果。
# 在GAN模型中,因判别器的输入则是具体的样本数据,要区分每个数据的分布特征,所以判别器使用实例归一化,
class WGAN_D(nn.Module): # 定义判别器类D :有两个卷积和两个全连接层
    def __init__(self,inputch=1):
        super(WGAN_D, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputch,64,4,2,1), # 输出形状为[batch,64,28,28]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(64,affine=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64,128,4,2,1),# 输出形状为[batch,64,14,14]
            nn.LeakyReLU(0.2,True),
            nn.InstanceNorm2d(128,affine=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(128*7*7,1024),
            nn.LeakyReLU(0.2,True)
        )
        self.fc2 = nn.Sequential(
            nn.InstanceNorm1d(1,affine=True),
            nn.Flatten(),
            nn.Linear(1024,1)
        )
    def forward(self,x,*arg): # 正向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = x.reshape(x.size(0),1,-1)
        x = self.fc2(x)
        return x.view(-1,1).squeeze(1)

# 在GAN模型中,因生成器的初始输入是随机值,所以生成器使用批量归一化。
class WGAN_G(nn.Module): # 定义生成器类G:有两个卷积和两个全连接层
    def __init__(self,input_size,input_n=1):
        super(WGAN_G, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_size * input_n,1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024,7*7*128),
            nn.ReLU(True),
            nn.BatchNorm1d(7*7*128)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 输出形状为[batch,64,14,14]
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 输出形状为[batch,64,28,28]
            nn.Tanh()
        )
    def forward(self,x,*arg): # 正向传播
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(x.size(0),128,7,7)
        x = self.upsample1(x)
        img = self.upsample2(x)
        return img

# 1.3 定义函数compute_gradient_penalty()完成梯度惩罚项
# 惩罚项的样本X_inter由一部分Pg分布和一部分Pr分布组成,同时对D(X_inter)求梯度,并计算梯度与1的平方差,最终得到gradient_penalties
lambda_gp = 10
# 计算梯度惩罚项
def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):
    # 获取一个随机数,作为真假样本的采样比例
    eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)
    # 按照eps比例生成真假样本采样值X_inter
    X_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)
    d_interpolates = D(X_inter,y_one_hot)
    fake = torch.full((real_samples.size(0),),1,device=device) # 计算梯度输出的掩码,在本例中需要对所有梯度进行计算,故需要按照样本个数生成全为1的张量。
    # 求梯度
    gradients = autograd.grad(outputs=d_interpolates, # 输出值outputs,传入计算过的张量结果
                              inputs=X_inter,# 待求梯度的输入值inputs,传入可导的张量,即requires_grad=True
                              grad_outputs=fake, # 传出梯度的掩码grad_outputs,使用1和0组成的掩码,在计算梯度之后,会将求导结果与该掩码进行相乘得到最终结果。
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True
                              )[0]
    gradients = gradients.view(gradients.size(0),-1)
    gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penaltys

# 1.4 定义模型的训练函数
# 定义函数train(),实现模型的训练过程。
# 在函数train()中,按照对抗神经网络专题(一)中的式(8-24)实现模型的损失函数。
# 判别器的loss为D(fake_samples)-D(real_samples)再加上联合分布样本的梯度惩罚项gradient_penalties,其中fake_samples为生成的模拟数据,real_Samples为真实数据,
# 生成器的loss为-D(fake_samples)。
def train(D,G,outdir,z_dimension,num_epochs=30):
    d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定义优化器
    g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)

    os.makedirs(outdir,exist_ok=True) # 创建输出文件夹
    # 在函数train()中,判别器和生成器是分开训练的。让判别器学习的次数多一些,判别器每训练5次,生成器优化1次。
    # WGAN_gp不会因为判别器准确率太高而引起生成器梯度消失的问题,所以好的判别器会让生成器有更好的模拟效果。
    for epoch in range(num_epochs):
        for i,(img,lab) in enumerate(train_loader):
            num_img = img.size(0)

            # 训练判别器
            real_img = img.to(device)
            y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)
            for ii in range(5): # 循环训练5次
                d_optimizer.zero_grad() # 梯度清零
                # 对real_img进行判别
                real_out = D(real_img,y_one_hot)
                # 生成随机值
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot) # 生成fake_img
                fake_out = D(fake_img,y_one_hot) # 对fake_img进行判别
                # 计算梯度惩罚项
                gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)
                # 计算判别器的loss
                d_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penalty
                d_loss.backward()
                d_optimizer.step()

            # 训练生成器
            for ii in range(1): # 训练一次
                g_optimizer.zero_grad() # 梯度清0
                z = torch.randn(num_img,z_dimension).to(device)
                fake_img = G(z,y_one_hot)
                fake_out = D(fake_img,y_one_hot)
                g_loss =  -torch.mean(fake_out)
                g_loss.backward()
                g_optimizer.step()
        # 输出可视化结果,并将生成的结果以图片的形式存储在硬盘中
        fake_images = to_img(fake_img.cpu().data)
        real_images = to_img(real_img.cpu().data)
        rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)
        imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))
        # 输出训练结果
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))
    # 保存训练模型
    torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))

# 1.5 定义函数,实现可视化模型结果:获取一部分测试数据,显示由模型生成的模拟数据。
def displayAndTest(D,G,z_dimension):    # 可视化结果
    sample = iter(test_loader)
    images, labels = sample.next()
    y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)
    num_img = images.size(0) # 获取样本个数
    with torch.no_grad():
        z = torch.randn(num_img, z_dimension).to(device) # 生成随机数
        fake_img = G(z, y_one_hot)
    fake_images = to_img(fake_img.cpu().data) # 生成模拟样本
    rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)
    imshow(torchvision.utils.make_grid(rel, nrow=10))
    print(labels[:10])

# 1.6 调用函数并训练模型:实例化判别器和生成器模型,并调用函数进行训练
if __name__ == '__main__':
    z_dimension = 40  # 设置输入随机数的维度

    D = WGAN_D().to(device)  # 实例化判别器
    G = WGAN_G(z_dimension).to(device)  # 实例化生成器
    train(D, G, './w_img', z_dimension) # 训练模型
    displayAndTest(D, G, z_dimension) # 输出可视化

おすすめ

転載: blog.csdn.net/qq_39237205/article/details/123756676