简单理解GAN对抗神经网络

1.简介

  • 先上一张,我最喜欢的流程图
  • G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像
  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片
    流程图

2.接下来我们将以小例子的形式,了解GAN

2.1 定义我们的超参数

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001		# 生成器的学习率
LR_D = 0.0001		# 判别器的学习率
N_IDEAS = 5			# 随机的想法
ART_CONPONENTS = 15	# 一共15个x
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_CONPONENTS) for _ in range(BATCH_SIZE)])		# 生成一批x数据

2.2 展示大师的画

# 展示我们美丽的生成数据      
# 大师的画都在最高和最低的画之间
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0], 2) +1 , c='#74BCFF',lw=3,label='upper bound')      
plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0], 2) +0 , c='#FF9359',lw=3,label='lower bound')      
# plt.show()      
      
def artist_words():     # 著名画家的画(真实的数据)      
    # 生成一些噪点      
    a = np.random.uniform(1,2,BATCH_SIZE)[:, np.newaxis]      
    paintings = a * np.power(PAINT_POINTS,2)+(a-1)      
    # paintings = torch.from_numpy(paintings).float()      
    plt.plot(np.linspace(-1,1,ART_CONPONENTS),paintings[1],label='really data',color='black',lw=3)                                                                           
    plt.legend(loc='upper right')      
    plt.show()      
    return paintings  

真正的大师

2.3 定义对抗网络

G = nn.Sequential(                      # 生成器
    nn.Linear(N_IDEAS,128),             # 随机的想法
    nn.ReLU(),
    nn.Linear(128,ART_CONPONENTS),      # 从随机的想法,制作画
    )

D = nn.Sequential(                      # 判别器
    nn.Linear(ART_CONPONENTS,128),
    nn.ReLU(),
    nn.Linear(128,1),
    nn.Sigmoid(),						# 转换为概率
    )

# 定义两个优化器
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)
opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)

2.4 训练GAN

for step in range(5000):
    artist_paintings = artist_words()           # 真实画家的画
    G_ideas = torch.randn(BATCH_SIZE,N_IDEAS)   # 随机的想法 (64,5)
    G_paintings = G(G_ideas)                    # (64,15)

    prob_artist0 = D(artist_paintings)			# 著名画家的画
    prob_artist1 = D(G_paintings)				# 新手画家的画

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))   # 判别器的误差计算公式
    G_loss = torch.mean(torch.log(1- prob_artist1))								 # 生成器的误差计算公式

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)		# 保留计算图,让我们的G的误差可以传递过去
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

2.5 训练中的可视化

plt.ion()

for step in range(5000):
    artist_paintings = artist_words()           # 真实画家的画
    G_ideas = torch.randn(BATCH_SIZE,N_IDEAS)   # 随机的想法 (64,5)
    G_paintings = G(G_ideas)                    # (64,15)

    prob_artist0 = D(artist_paintings)
    prob_artist1 = D(G_paintings)

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
    G_loss = torch.mean(torch.log(1- prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

    if step % 50 == 0:
        plt.cla()
        plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4AD631',lw=3,label='Generated painting')
        plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2) + 1, c='#74BCFF', lw=3,label='upper bound')
        plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0],2) + 0, c='#FF9359', lw=3,label='lower bound')

        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)'% prob_artist0.data.numpy().mean(),fontdict={'size':13})
        plt.text(-.5, 2, 'D score=%.2f (-1.38 for G to converge)'% -D_loss.data.numpy(),fontdict={'size':13})
        plt.ylim((0,3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.02)

plt.ioff()
plt.show()
发布了31 篇原创文章 · 获赞 13 · 访问量 9902

猜你喜欢

转载自blog.csdn.net/qq_43497702/article/details/98117006