Pytorch实现一个简单的生成对抗网络GAN

最近看了一些GAN的资料,把自己易混淆的内容做一个总结

生成式模型

        我们以往通常接触到的深度学习模型一般都是些判别模型,即通过训练样本训练模型,然后利用模型对新样本进行判别或预测。判别模型体现了深度学习的学习能力,然而,人工智能的强大,不应只有从已知中学习,还应该有创造能力,才就真正有趣。而生成式模型所体现的就是深度学习的创造能力。与判别式模型的工作流程恰恰相反,生成式模型是根据规则生成新的样本。
        生成式模型,主要包括变分自编码器(VAE)生成式对抗网络(GAN)这两种思路。VAE基于贝叶斯推理,其目的是潜在地建模,从模型中采样新的数据。GAN是利用博弈论思想,以求得达到纳什均衡(如何通俗的理解纳什均衡点?)的判别器网络(D)和生成器网络(G)。

GAN架构

        GAN的直观理解,可以想象一个名画伪造者想伪造一幅达芬奇的画作,开始时伪造技术不精,但他将自己的一些赝品和达芬奇的作品混在一起,请一个艺术鉴赏家进行真实性评估,并向伪造者反馈真伪程度。伪造者根据反馈,改进自己的赝品。随着时间的推移,那么造假者的造假能力越来却强,鉴赏家的能力也越来越强。而赝品,则越来越像真画。

        以上,便是GAN的原理。一个造假者G,一个鉴赏家D。他们训练的目的都是为了打败对方。

        下图是我从书中截取的GAN架构图,简单明了。

GAN的损失函数

     

  假设x表示图像,D(x)表示判别网络,是一个二元分类器,那么它的输出即为图片x来自训练数据(而不是产生网络输出的假图片)的概率。对于产生网络,首先定义从标准正态分布种采样的数据z,则G(z)表示的是将向量z映射到空间的生成器函数。G的目标是估计训练数据的分布(Pdate)以生成假样本。因此D(G(z))是产生网络G的输出是真实图像的概率。即判别网络D和产生网络G在做一个极大极小的博弈,其中D试图最大化它正确分辨真假数据(logD(x))的概率,而G试图最小化D预测其输出是假的概率(log(1-d(G(x))))。

Pytorch实现一个GAN

      

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt



# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001           # learning rate for generator
LR_D = 0.0001           # learning rate for discriminator
N_IDEAS = 5             # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15     # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
print(PAINT_POINTS)
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.legend(loc='upper right')
plt.show()


def artist_works():     # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()   # something about continuous plotting

for step in range(10000):
    artist_paintings = artist_works()           # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)          # D try to increase this prob
    prob_artist1 = D(G_paintings)               # D try to reduce this prob

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)      # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()

参考资料:《Python深度学习 基于pytorch》,《深度学习与图像识别 原理与实践》,莫烦Python

发布了134 篇原创文章 · 获赞 38 · 访问量 9万+

猜你喜欢

转载自blog.csdn.net/rytyy/article/details/105128976