GAN原理及Pytorch框架实现GAN(比较容易理解)

目录

1.初识GAN

什么是GAN?

GAN应用场景

2.GAN原理结构

(1)生成对抗网络子网络

(2)结构图

(1)生成器 

(2)判别器

(3)训练技巧 

3.GAN网络模型选择

(1)生成模型

(2)判别模型

4.GAN训练目标函数

(1)生成模型

(2)判别模型

5.训练算法

6.GAN代码实现

7.mainWindow窗口显示生成器生成的图片

拓展


1.初识GAN

  • 什么是GAN?

    • GAN(Generative Adversarial Networks):生成对抗网络;
    • GAN是当前人工智能领域最为重要的研究热点之一,并且应用非常的广泛;
  • 2014年,Universite de Montreal 大学Yoshua Bengio(2018年图灵奖获得者)的学生Ian Goodfellow提出 生成对抗网络(Generative adj-terminal networks,简称 GAN),从而开辟了深度学习最赤手可热的研究方向。
  • 从2014-2019年,GAN的研究稳步推进,研究捷报频传,最新的GAN算法在图片生成上的效果甚至达到了肉眼很难分辨的程度。由于GAN的发明,Ian Goodfello荣获GAN之父称号,并获得了2017年麻省理工大学科技评论颁奖的35 Innovators Goodfellow奖项。
  • 该方法利用了两个网络,一个称为生成网络,另一个称为鉴别网络,可用于以音频、视频和文本的形式产生不同寻常的创造性输出。他的这项研究,在人工智能文献中被广泛引用。
  • GAN应用场景

    • 图像编辑:给定一张图像,可以在该图像的基础之上生成各种各样的图像;
    • 恶意攻击检测:深度学习生成的模型是可以被黑客攻击,利用甚至控制的。为了对抗这样的逆向攻击(adversarial attacks),可以训练对抗神经网络去生成更多的虚假训练数据作为假想敌,让模型在演习中去识别出这些虚假数据,GAN生成的虚假数据让正在做分类的模型更加稳健;
    • 数据生成:例如医疗领域,缺少训练数据是应用深度学习的最大障碍。数据增强的传统做法是将原图像拉伸旋转剪切,但这毕竟还是原来的图像,通过使用GAN,能够生成更多类似的数据;
    • 注意力预测:人类在看一张图片时,往往只关注特定的部分,而通过GAN模型,可以预测出人类关心的区域在哪里。
    • 三位结构生成:pix2vox是一个基于GAN的开源工具,能够根据手绘的二维图片,生成对应的三维结构,不止有对应的形状,还会生成对应的颜色,有了这样的工具,就能降低3D建模的门槛,从而让3D打印更容易的落地。

2.GAN原理结构

提示:下面的原理解释可能对于读者来说比较枯燥无味,但是还是希望读者可以坚持看完原理,因为只有这样你才能真正的理解GAN的实现原理。

(1)生成对抗网络子网络

 GAN包含:生成网络(Generator Network)和判别网络(Discriminator Network),其中生成网络Gen负责学习样本的真实分布,判别网络Dis负责将对生成网络生成的样本和真实样本分别进行判别。

(2)结构图

(1)生成器 

生成模型以随机噪声(Random noise)或者类似的控制变量作为输入,生成器一般使用多层的神经网络实现,其输出为生成的样本,也就是一张假的图片(fake image);这样样本和真实给定的样本一起给判别器模型训练。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np
#对于生成器,输入的为正态分布随机数
#输出为: [1,28,28]图片

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=100,out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256,out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512,out_features=784),
            torch.nn.Tanh()#对于生成器使用tanh激活函数更好
        )
    def forward(self,input):
        x = self.fc(input)
        img = x.view(-1,28,28)
        return img

(2)判别器

判别器模型是一个二分类器,判别一个样本是真实的样本还是生成器生成的样本,一般也是使用神经网络实现。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np

#判别器的输入为一张图片
#输出为二分类的概率值
#判别器对log(1 - D(G(z)))的判别作为生成器的损失值

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=784,out_features=512),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=512,out_features=256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=256,out_features=1),
            torch.nn.Sigmoid()
        )
    def forward(self,input):
        x = input.view(-1,784)
        x = self.fc(x)
        return x

(3)训练技巧 

  • 对于生成模型:训练目标是让生成的数据尽可能的与真实数据相似,最小化判别模型的判别准确率。
  • 对于判别模型:训练目标是最大化判别准确率,即区分样本是真实样本还是生成器生成的样本。

可以发现,这个过程是矛盾的,因此:

  • 在训练的过程中采用交替优化的方式,每一次迭代时分为两个阶段:
    • 第一个阶段:首先固定判别模型,优化生成模型,使得生成的数据备判别模型判定为真样本的概率尽可能的高。
    • 第二个阶段:固定生成模型,优化盘被模型,提高判别模型的分类准确率。

提示:在训练过程中,生成器努力地让生成的图像更加的真实,而判别器则努力地识别生成器图片的真假,这是是一个相互博弈的过程,互相提升自己,也就是不断的进行对抗的过程。随着训练的进行,生成模型产生的样本和真实样本几乎没有什么差别,判别模型也无法准确的判别一个样本的真假,此时的分类错误率为0.5(那什均衡)

3.GAN网络模型选择

生成对抗网络是一个抽象的框架,并没有指定生成模型和判别模型具体为哪一种模型,可以是神经网络模型,也可以是卷积神经网络 模型或者其他的机器学习模型。

(1)生成模型

        在本文中,生气模型选择是神经网络模型。根据类型等输入变量来生成图像之类的样本数据,生成模型接收的输入是类别之类的隐变量和随机噪声,输出与训练样本相似的样本数据(比如图片之类的)。

(2)判别模型

        判别模型一般用分类问题的神经网络,用于区分样本的真假(给定的真实数据和生成器生成的数据),是一个二分类问题。

4.GAN训练目标函数

提示:在确定生成模型和判别模型之前,首先了解一下logistic回归模型:

logistic回归即对数概率回归,是一种二分类问题的分类算法,使用sigmoid函数估计出样本属于正样本的概率(关于细节推导,建议看《机器学习原理,算法与应用》)。

logistic回归似然函数:

  • 回归对数函数和生成对抗区别:
    • logistic回归在训练达到最优点处时,负样本的预测输出接近于0;
    • 生成对抗网络中判别模型对抗样本的输出概率值在最优点处接近于0.5,。 

(1)生成模型

(2)判别模型

5.训练算法

  

6.GAN代码实现

提示:代码放在了Github上,读者自行下载:https://github.com/KeepTryingTo/Pytorch-GAN

 

7.mainWindow窗口显示生成器生成的图片

提示:这里编写了一个显示生成器显示图片的程序(mainWindow.py),加载之前训练之后保存的生成器模型,之后可使用该模型进行随机生成数字图片,如下:

(1)运行mainWindow.py 初始界面如下

 (2)点击生成图片按钮,每一次的点击生成的数字都不是一样的。

拓展

pytorch中的detach作用

参考文章:

《TensorFlow深度学习》

《机器学习原理,算法与应用》

https://www.jiqizhixin.com/articles/2019-04-15-6

https://b23.tv/6P7M8mh

猜你喜欢

转载自blog.csdn.net/Keep_Trying_Go/article/details/130362296
GAN