基本的な考え方
概要
GAN は深層学習モデルであり、ランダム ノイズから画像、音声、テキストなどのリアルなデータを生成するために使用される教師なし学習アルゴリズムです。GAN の構造は、Generator と Discriminator という 2 つのニューラル ネットワークで構成されており、これらは互いに競合してモデル全体の学習を推進します。
2 つの主なコンポーネント:
1. ジェネレーター:
ジェネレーターの目的は、ランダム ノイズ (通常は正規分布または一様分布からサンプリングされたベクトル) を現実的なデータ サンプルに変換することです。このプロセスは、ジェネレーターがデータの分布を学習し、実際のデータに似た新しいサンプルの作成を試みると理解できます。最初は、ジェネレーターの出力はランダムである可能性がありますが、トレーニングが進むにつれて、ディスクリミネーターを騙すためのより現実的なデータが徐々に生成されます。
2. ディスクリミネーター (ディスクリミネーター):
識別器のタスクは、入力データ サンプルを分類すること、つまり、それが本物のデータであるか、ジェネレータによって生成された偽のデータであるかを判断することです。ディスクリミネーターはバイナリ分類器であり、その目的は、ジェネレーターによって生成された本物のデータと偽のデータをできるだけ正確に区別することです。
トレーニングプロセス
1. トレーニングの開始時に、ジェネレーターはいくつかの偽のデータ サンプルをランダムに生成し、実際のデータとともにディスクリミネーターに提供します。
2. 弁別器は入力データを分類し、確率推定値を出力します (偽のデータの場合は 0、本物のデータの場合は 1)。
3. 識別器の出力に従って、生成器によって生成されたデータが実際のデータであると判断される確率を計算し、この確率を生成器の「損失」として使用します。
4. 次に、ジェネレータの損失に応じて、ジェネレータがより現実的なデータ サンプルを生成できるように、ジェネレータのパラメータが更新されます。
5. 次に、偽のデータ サンプルのバッチを再度ランダムに生成し、本物のデータとともに識別器に提供し、上記のプロセスを繰り返します。
この競争とゲームのプロセスを通じて、ジェネレーターとディスクリミネーターは徐々に能力を最適化し、ジェネレーターは非常に現実的なデータ サンプルを生成できるようになりますが、ディスクリミネーターは本物と偽物を正確に区別できなくなります。
コードとコメント
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Hyper Parameters
BATCH_SIZE = 64
# 生成器学习率
LR_G = 0.0001 # learning rate for generator
# 判别器学习率
LR_D = 0.0001 # learning rate for discriminator
N_IDEAS = 5 # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15 # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
# 定义函数artist_works,用于生成来自著名艺术家的真实画作数据
def artist_works():
a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
paintings = torch.from_numpy(paintings).float()
return paintings
# 定义生成器(Generator)和判别器(Discriminator)
# 初级画家
G = nn.Sequential(
nn.Linear(N_IDEAS, 128), # 生成器输入为随机噪声数据
nn.ReLU(),
nn.Linear(128, ART_COMPONENTS), # 生成器输出为生成的艺术作品
)
# 初级鉴赏家
D = nn.Sequential(
nn.Linear(ART_COMPONENTS, 128), # 判别器输入为艺术作品数据
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid(), # 判别器输出为对艺术作品的真假概率
)
# 定义两个优化器,分别用于优化生成器和判别器的参数
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
# 开始GAN的训练
plt.ion() # 打开交互式绘图
for step in range(10000):
# 获取来自艺术家的真实画作数据
artist_paintings = artist_works()
# 生成随机的噪声数据
G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)
# 生成器生成假的艺术画作
G_paintings = G(G_ideas)
# 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
prob_artist1 = D(G_paintings)
# 计算生成器的损失
G_loss = torch.mean(torch.log(1. - prob_artist1))
opt_G.zero_grad() # 清空生成器的梯度
G_loss.backward() # 反向传播计算生成器的梯度
opt_G.step() # 优化生成器的参数
# 判别器对真实画作进行判断,试图增大判别器对真实画作的概率
prob_artist0 = D(artist_paintings)
# 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
prob_artist1 = D(G_paintings.detach())
# 计算判别器的损失
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
opt_D.zero_grad() # 清空判别器的梯度
D_loss.backward(retain_graph=True) # 反向传播计算判别器的梯度(保留计算图以供下一次计算)
opt_D.step() # 优化判别器的参数
if step % 50 == 0: # 每隔一段时间进行绘图显示
# 绘制生成的画作、上界和下界
plt.cla()
plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={
'size': 13})
plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={
'size': 13})
plt.ylim((0, 3))
plt.legend(loc='upper right', fontsize=10)
plt.draw()
plt.pause(0.01)
plt.ioff() # 关闭交互式绘图
plt.show() # 展示绘制的图像