目录
NET
多层感知器版:
##GAN网络,多层感知器版
##判别网络
class discriminator(nn.Module):
    """MLP discriminator: flattened 28x28 image -> one unnormalised score.

    The first layer takes 28 * 28 = 784 input features, so inputs are
    expected as (batch, 784). There is no final sigmoid: the output is a raw
    logit, meant for a logits-based loss such as ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(28 * 28, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
        ]
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        return self.dis(x)
##生成网络
class generator(nn.Module):
    """MLP generator: noise vector of length ``in_size`` -> flattened image.

    Emits 784 (= 28 * 28) values squashed by tanh into [-1, 1].

    Args:
        in_size: length of the input noise vector (default 96).
    """

    def __init__(self, in_size=96):
        super().__init__()
        # The author notes that a 256-unit hidden layer gave noticeably worse
        # samples, hence the 1024-wide layers.
        width = 1024
        self.gen = nn.Sequential(
            nn.Linear(in_size, width),
            nn.ReLU(True),
            nn.Linear(width, width),
            nn.ReLU(True),
            nn.Linear(width, 28 * 28),
            nn.Tanh(),  # output range [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)
卷积版
###GAN 卷积版
class DC_discriminator(nn.Module):
    """Convolutional discriminator: (N, C, 28, 28) image -> (N, 1) logit.

    Two 5x5 conv + 2x2 max-pool stages reduce 28x28 to 64 x 4 x 4 = 1024
    features, which is exactly what the first Linear layer expects. The
    output is a raw logit (no sigmoid).

    Args:
        in_channels: number of input image channels. Defaults to 3 (the
            previously hard-coded value); pass 1 for grayscale data.
    """

    def __init__(self, in_channels=3):
        super(DC_discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, 1),  # 28x28 -> 24x24
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(2, 2),                # 24x24 -> 12x12
            nn.Conv2d(32, 64, 5, 1),           # 12x12 -> 8x8
            # NOTE(review): slope 0.02 here vs 0.2 in the first block; kept
            # as-is — confirm this difference is intentional.
            nn.LeakyReLU(0.02, True),
            nn.MaxPool2d(2, 2),                # 8x8 -> 4x4
        )
        # No BatchNorm in the discriminator: the author reports that BN here
        # ruined 3-channel colour results.
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),  # 64 * 4 * 4 flattened features
            nn.LeakyReLU(0.02, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.shape[0], -1)  # flatten to (N, 1024)
        return self.fc(x)
##生成网络
class DC_generator(nn.Module):
    """Convolutional generator: noise (N, in_size) -> image (N, C, 28, 28) in [-1, 1].

    A fully connected stem produces a 128 x 7 x 7 feature map. Two stride-2
    transposed convolutions then upsample it 7 -> 14 -> 28, and tanh bounds
    the output.

    Note: the BatchNorm1d layers in training mode need a batch size > 1.

    Args:
        in_size: length of the input noise vector (default 96).
        out_channels: channels of the generated image. Defaults to 3 (the
            previously hard-coded value); pass 1 for grayscale output.
    """

    def __init__(self, in_size=96, out_channels=3):
        super(DC_generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_size, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),                    # 7x7 -> 14x14
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, out_channels, 4, 2, padding=1),   # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        return self.conv(x)
损失函数:
## 定义对抗网络的损失函数
## 使用二元交叉熵 (BCEWithLogitsLoss),输入为未经 sigmoid 的 logits
# Binary cross-entropy applied to raw logits (sigmoid folded into the loss).
bce_loss = nn.BCEWithLogitsLoss()


def discriminator_loss(logits_real, logits_fake):  # discriminator loss
    """Standard GAN discriminator loss.

    Real samples are labelled 1 and generated samples 0; the two BCE terms
    are summed.

    Args:
        logits_real: discriminator logits for a batch of real images.
        logits_fake: discriminator logits for a batch of generated images.

    Returns:
        Scalar loss tensor.
    """
    # *_like targets inherit device, dtype and shape from the logits, so this
    # works on CPU as well as GPU (the old targets were hard-wired to .cuda()).
    true_labels = torch.ones_like(logits_real)
    false_labels = torch.zeros_like(logits_fake)
    return bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)


# Generator loss
def generator_loss(logits_fake):  # generator loss
    """Non-saturating GAN generator loss.

    The generator is rewarded when the discriminator labels its samples as
    real, so the fake logits are scored against target 1.

    Args:
        logits_fake: discriminator logits for a batch of generated images.

    Returns:
        Scalar loss tensor.
    """
    true_labels = torch.ones_like(logits_fake)
    return bce_loss(logits_fake, true_labels)
## 这里定义 GAN 的损失函数
## 这里用的是最小二乘
def ls_discriminator_loss(logits_real, logits_fake):
    """Least-squares (LSGAN) discriminator loss.

    Penalises the squared distance of real scores from 1 and of fake scores
    from 0, each averaged over the batch and weighted by 1/2.
    """
    real_term = (logits_real - 1).pow(2).mean()
    fake_term = logits_fake.pow(2).mean()
    return 0.5 * real_term + 0.5 * fake_term
def ls_generator_loss(logits_fake):
    """Least-squares (LSGAN) generator loss: half the mean squared distance of fake scores from 1."""
    return 0.5 * (logits_fake - 1).pow(2).mean()
train:
import torch
import torchvision.utils
import six_Net
import torch.nn as nn
import tqdm
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader, sampler
from torch import optim
from torchvision.utils import save_image
##设定参数
# Dataset split sizes. NOTE(review): not referenced by the visible training code.
NUM_TRAIN, NUM_VAL = 50000, 5000
# Length of the generator's input noise vector.
NOISE_DIM = 96
# Mini-batch size for the training DataLoader.
batch_size = 128
def show_images(images):  # plotting helper
    """Plot a batch of images on a near-square grid of axes.

    Each image is flattened and then reshaped to a square of side
    ceil(sqrt(pixels)), so the images are assumed to be square.

    Args:
        images: array-like with the batch on the first axis.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    n_images, n_pixels = flat.shape
    grid = int(np.ceil(np.sqrt(n_images)))
    side = int(np.ceil(np.sqrt(n_pixels)))

    plt.figure(figsize=(grid, grid))
    spec = gridspec.GridSpec(grid, grid)
    spec.update(wspace=0.05, hspace=0.05)

    for idx in range(n_images):
        ax = plt.subplot(spec[idx])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(flat[idx].reshape([side, side]))
def preprocess_img(x):
    """Convert an image to a tensor and rescale it from [0, 1] to [-1, 1].

    Args:
        x: an image in any form accepted by ``transforms.ToTensor``.
    """
    to_tensor = transforms.ToTensor()
    scaled = to_tensor(x)  # values in [0., 1.]
    # Subtract 0.5 and divide by 0.5 -> values in [-1., 1.]
    return scaled.sub(0.5).div(0.5)
def deprocess_img(x):
    """Map values from [-1, 1] back to [0, 1] (inverse of the [-1, 1] rescaling)."""
    return 0.5 * (x + 1.0)
# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST training split, rescaled to [-1, 1] by preprocess_img.
# download=False: the dataset files must already exist under ./data.
train_set = mnist.MNIST('./data', train=True, transform=preprocess_img, download=False)
# Shuffle so each epoch presents the batches in a new order; with
# shuffle=False every epoch repeated the exact same batch sequence.
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
def get_optimizer(net):
    """Build an Adam optimizer over ``net``'s parameters (lr=3e-4, betas=(0.5, 0.999))."""
    return optim.Adam(
        params=net.parameters(),
        lr=3e-4,
        betas=(0.5, 0.999),
    )
def to_img(x):
    '''
    Convert generator output in [-1, 1] back to image tensors.

    Values are rescaled to [0, 1], clamped into that range, and reshaped to
    (N, 1, 28, 28), where N is the size of x's first dimension.
    '''
    batch = x.shape[0]
    pixels = ((x + 1.) * 0.5).clamp(min=0, max=1)
    return pixels.view(batch, 1, 28, 28)
##定义训练
def train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
              discriminator_loss, generator_loss,
              num_epochs=10, noise_size=96, num_img=6):
    """Alternately train a GAN discriminator and generator on ``train_data``.

    Each iteration performs one discriminator update followed by one
    generator update, both driven by the same batch of uniform noise.
    Every 20 iterations the losses are printed and the current fakes are
    drawn on a live ``num_img`` x ``num_img`` grid. After each epoch, the
    last batch of fakes is saved to ``./out/ima_<epoch>.png``.

    Args:
        D_net, G_net: discriminator and generator modules.
        D_optimizer, G_optimizer: optimizers for D_net and G_net respectively.
        discriminator_loss: callable(logits_real, logits_fake) -> scalar loss.
        generator_loss: callable(logits_fake) -> scalar loss.
        num_epochs: number of passes over ``train_data``.
        noise_size: length of each noise vector; must match G_net's input size.
        num_img: side length of the preview grid.
    """
    # squeeze=False keeps `axes` 2-D even when num_img == 1.
    _, axes = plt.subplots(num_img, num_img, figsize=(num_img, num_img), squeeze=False)
    plt.ion()  # Turn the interactive mode on, continuously plot
    # save_image does not create missing directories.
    os.makedirs('./out', exist_ok=True)

    for epoch in range(num_epochs):
        print()
        fake_images = None
        for iteration, (ima, _) in enumerate(train_data):
            bs = ima.shape[0]
            # NOTE(review): flattening suits the MLP discriminator; the
            # convolutional DC_discriminator would need (N, C, 28, 28) input.
            real_data = ima.view(bs, -1).to(device)  # real data
            # Uniform noise in [-1, 1).
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5
            g_fake_seed = sample_noise.to(device)

            # ---- Discriminator step ----
            logits_real = D_net(real_data)
            fake_images = G_net(g_fake_seed)
            # detach(): the discriminator loss must not backpropagate into G.
            logits_fake = D_net(fake_images.detach())
            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # update the discriminator

            # ---- Generator step ----
            fake_images = G_net(g_fake_seed)
            # Score the fakes with the just-updated discriminator.
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            # Every 20 iterations, log losses and draw the generated images.
            if iteration % 20 == 0:
                print(f'Epoch: {epoch + 1} | Iter: {iteration} | '
                      f'D_loss: {d_total_error.item()} | '
                      f'G_loss:{g_error.item()}')
                im_gen = deprocess_img(fake_images.detach().cpu().numpy())
                _draw_samples(axes, im_gen, num_img, epoch, iteration)

        if fake_images is not None:
            pic = to_img(fake_images.detach().cpu())
            torchvision.utils.save_image(pic, f'./out/ima_{epoch + 1}.png')


def _draw_samples(axes, images, num_img, epoch, iteration):
    """Draw up to ``num_img ** 2`` 28x28 grayscale samples onto a 2-D axes grid."""
    # A short (e.g. last) batch may hold fewer images than the grid has cells.
    for i in range(min(num_img ** 2, len(images))):
        ax = axes[i // num_img][i % num_img]
        ax.imshow(np.reshape(images[i], (28, 28)), cmap='gray')
        ax.set_xticks(())
        ax.set_yticks(())
    plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
    plt.pause(0.01)
# Build the MLP GAN from six_Net, give each network its own Adam optimizer,
# and train it with the BCE-based losses.
D_net = six_Net.discriminator().to(device)
G_net = six_Net.generator(NOISE_DIM).to(device)
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)
discriminator_loss = six_Net.discriminator_loss
generator_loss = six_Net.generator_loss
# Pass NOISE_DIM rather than a literal 96 so the sampled noise always matches
# the generator's input size if NOISE_DIM changes.
train_gen(D_net, G_net,
          D_optimizer, G_optimizer,
          discriminator_loss, generator_loss,
          num_epochs=10, noise_size=NOISE_DIM, num_img=5)
总结:
网络由两个小网络组成,一个负责判别,一个负责生成
判别网络:
先将真实图片送入判别器,每张图片得到一个输出值(logit,经过 sigmoid 后才是"判为真"的概率)
随机生成一组噪声,送入生成器,得到假数据
再将假数据送入判别器,得到假数据对应的输出值
最后把真实图片的输出与标签 1、假数据的输出与标签 0 分别送入损失函数计算,两项相加
反向传播
生成网络:
将随机生成的噪声送入生成器,得到假数据
将假数据送入判别器,看判别器能否把它识别为假,得到对应的输出值
最后以标签 1(真)送入损失函数计算——生成器的目标是让判别器把假数据当成真的
反向传播
值得注意的是:如果生成的是三通道的彩色人物图片,要去掉判别器里的 BN 层,否则生成的图片人眼根本看不出内容