ディープラーニング - 対立ネットワーク GAN を生成する

基本的な考え方

概要

GAN は深層学習モデルであり、ランダム ノイズから画像、音声、テキストなどのリアルなデータを生成するために使用される教師なし学習アルゴリズムです。GAN の構造は、Generator と Discriminator という 2 つのニューラル ネットワークで構成されており、これらは互いに競合してモデル全体の学習を推進します。

2 つの主なコンポーネント:

1. ジェネレーター:

ジェネレーターの目的は、ランダム ノイズ (通常は正規分布または一様分布からサンプリングされたベクトル) を現実的なデータ サンプルに変換することです。このプロセスは、ジェネレーターがデータの分布を学習し、実際のデータに似た新しいサンプルの作成を試みると理解できます。最初は、ジェネレーターの出力はランダムである可能性がありますが、トレーニングが進むにつれて、ディスクリミネーターを騙すためのより現実的なデータが徐々に生成されます。

2. ディスクリミネーター (ディスクリミネーター):

識別器のタスクは、入力データ サンプルを分類すること、つまり、それが本物のデータであるか、ジェネレータによって生成された偽のデータであるかを判断することです。ディスクリミネーターはバイナリ分類器であり、その目的は、ジェネレーターによって生成された本物のデータと偽のデータをできるだけ正確に区別することです。

トレーニングプロセス

1. トレーニングの開始時に、ジェネレーターはいくつかの偽のデータ サンプルをランダムに生成し、実際のデータとともにディスクリミネーターに提供します。
2. 弁別器は入力データを分類し、確率推定値を出力します (偽のデータの場合は 0、本物のデータの場合は 1)。
3. 識別器の出力に従って、生成器によって生成されたデータが実際のデータであると判断される確率を計算し、この確率を生成器の「損失」として使用します。
4. 次に、ジェネレータの損失に応じて、ジェネレータがより現実的なデータ サンプルを生成できるように、ジェネレータのパラメータが更新されます。
5. 次に、偽のデータ サンプルのバッチを再度ランダムに生成し、本物のデータとともに識別器に提供し、上記のプロセスを繰り返します。

この競争とゲームのプロセスを通じて、ジェネレーターとディスクリミネーターは徐々に能力を最適化し、ジェネレーターは非常に現実的なデータ サンプルを生成できるようになりますが、ディスクリミネーターは本物と偽物を正確に区別できなくなります。

コードとコメント

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Hyper Parameters
BATCH_SIZE = 64
# 生成器学习率
LR_G = 0.0001  # learning rate for generator
# 判别器学习率
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 5  # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15  # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])


# 定义函数artist_works,用于生成来自著名艺术家的真实画作数据
def artist_works():
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings


# 定义生成器(Generator)和判别器(Discriminator)
# 初级画家
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # 生成器输入为随机噪声数据
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # 生成器输出为生成的艺术作品
)

# 初级鉴赏家
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # 判别器输入为艺术作品数据
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # 判别器输出为对艺术作品的真假概率
)

# 定义两个优化器,分别用于优化生成器和判别器的参数
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

# 开始GAN的训练
plt.ion()  # 打开交互式绘图

for step in range(10000):
    # 获取来自艺术家的真实画作数据
    artist_paintings = artist_works()
    # 生成随机的噪声数据
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)
    # 生成器生成假的艺术画作
    G_paintings = G(G_ideas)

    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings)
    # 计算生成器的损失
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_G.zero_grad()  # 清空生成器的梯度
    G_loss.backward()  # 反向传播计算生成器的梯度
    opt_G.step()  # 优化生成器的参数

    # 判别器对真实画作进行判断,试图增大判别器对真实画作的概率
    prob_artist0 = D(artist_paintings)
    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings.detach())
    # 计算判别器的损失
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()  # 清空判别器的梯度
    D_loss.backward(retain_graph=True)  # 反向传播计算判别器的梯度(保留计算图以供下一次计算)
    opt_D.step()  # 优化判别器的参数

    if step % 50 == 0:  # 每隔一段时间进行绘图显示
        # 绘制生成的画作、上界和下界
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={
    
    'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={
    
    'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()  # 关闭交互式绘图
plt.show()  # 展示绘制的图像

演算結果

ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/Elon15/article/details/131832395