Tabla de contenido
Versión de perceptrón multicapa:
Red (Net)
Versión de perceptrón multicapa:
## GAN, multilayer-perceptron version
## Discriminator network
class discriminator(nn.Module):
    """Fully connected discriminator: 784 -> 256 -> 256 -> 1.

    No ``Sigmoid`` is applied at the end (it is disabled in the original
    code), so the network returns one raw score per sample rather than a
    value squashed into 0..1. The ``BatchNorm1d`` layers are likewise left
    out, as they were disabled in the original.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(in_features=28 * 28, out_features=256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(in_features=256, out_features=256),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(in_features=256, out_features=1),
        ]
        # Same layer order as before, so state_dict keys stay dis.0/2/4.
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (last dimension 784) through the stack and return the scores."""
        return self.dis(x)
## Generator network
class generator(nn.Module):
    """Fully connected generator: in_size -> 1024 -> 1024 -> 784.

    The final ``Tanh`` keeps every output value between -1 and 1.
    """

    def __init__(self, in_size = 96):
        super().__init__()
        # Original author's note: a hidden width of 256 gave quite poor results.
        hidden = 1024
        stack = (
            nn.Linear(in_size, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 784),
            nn.Tanh(),
        )
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Map noise ``x`` (last dimension ``in_size``) to 784 values in [-1, 1]."""
        return self.gen(x)
Versión de convolución
### GAN, convolutional version
class DC_discriminator(nn.Module):
    """Convolutional discriminator for 3-channel images.

    Two conv / LeakyReLU / max-pool stages feed a two-layer fully connected
    head that returns one raw score per sample (no sigmoid).
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(3, 32, kernel_size=5, stride=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1),
            # NOTE(review): slope is 0.02 here but 0.2 above — kept as in the
            # original; confirm whether the difference is intentional.
            nn.LeakyReLU(negative_slope=0.02, inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv1 = nn.Sequential(*feature_layers)

        # 1024 input features equals 64 * 4 * 4, which is what the stack above
        # yields for a 28x28 input — assumption about the input size, confirm.
        head_layers = [
            nn.Linear(1024, 1024),
            nn.LeakyReLU(negative_slope=0.02, inplace=True),
            nn.Linear(1024, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Extract features, flatten them per sample and score them."""
        features = self.conv1(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
## Generator network
class DC_generator(nn.Module):
    """Transposed-convolution generator producing 3-channel 28x28 images.

    A fully connected stem expands the noise vector to 128 feature maps of
    size 7x7; two stride-2 transposed convolutions then upsample 7 -> 14 -> 28.
    The final ``Tanh`` keeps outputs in [-1, 1].
    """

    def __init__(self, in_size = 96):
        super().__init__()
        stem_width = 7 * 7 * 128
        stem = [
            nn.Linear(in_size, 1024),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, stem_width),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(stem_width),
        ]
        self.fc = nn.Sequential(*stem)

        upsample = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.conv = nn.Sequential(*upsample)

    def forward(self, x):
        """Turn noise of shape ``(batch, in_size)`` into an image batch."""
        hidden = self.fc(x)
        maps = hidden.view(hidden.size(0), 128, 7, 7)
        return self.conv(maps)
Función de pérdida:
## Loss functions for the adversarial networks
## Binary cross-entropy computed on raw logits (BCEWithLogitsLoss)
# Binary cross-entropy on raw scores; the sigmoid is applied inside the loss,
# which is why the discriminators above end without one.
bce_loss = nn.BCEWithLogitsLoss()


def discriminator_loss(logits_real, logits_fake):
    """Return the discriminator loss for one batch.

    Real samples are pushed towards label 1 and generated samples towards
    label 0; the two binary cross-entropy terms are summed.

    Args:
        logits_real: raw discriminator scores for real samples.
        logits_fake: raw discriminator scores for generated samples.

    Returns:
        A scalar tensor holding the summed loss.
    """
    # Build the targets from the logits themselves so they always share the
    # logits' shape, dtype and device. The previous hard-coded .cuda() call
    # failed on machines without a GPU.
    true_labels = torch.ones_like(logits_real)
    false_labels = torch.zeros_like(logits_fake)
    real_term = bce_loss(logits_real, true_labels)
    fake_term = bce_loss(logits_fake, false_labels)
    return real_term + fake_term


def generator_loss(logits_fake):
    """Return the generator loss for one batch.

    The generator is rewarded when the discriminator scores its samples as
    real, so the targets are all ones.

    Args:
        logits_fake: raw discriminator scores for generated samples.

    Returns:
        A scalar tensor holding the loss.
    """
    # Targets follow the logits' device instead of assuming CUDA.
    true_labels = torch.ones_like(logits_fake)
    return bce_loss(logits_fake, true_labels)
## Alternative GAN loss functions
## These use the least-squares formulation
def ls_discriminator_loss(logits_real, logits_fake):
    """Least-squares discriminator loss.

    Half the mean squared distance of the real scores from 1, plus half the
    mean squared distance of the fake scores from 0.
    """
    real_sq_err = (logits_real - 1) ** 2
    fake_sq_err = logits_fake ** 2
    real_term = 0.5 * real_sq_err.mean()
    fake_term = 0.5 * fake_sq_err.mean()
    return real_term + fake_term
def ls_generator_loss(logits_fake):
    """Least-squares generator loss: half the mean squared distance of the fake scores from 1."""
    sq_err = (logits_fake - 1) ** 2
    return 0.5 * sq_err.mean()
Entrenamiento:
import torch
import torchvision.utils
import six_Net
import torch.nn as nn
import tqdm
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader, sampler
from torch import optim
from torchvision.utils import save_image
## Hyperparameters
NUM_TRAIN = 50000  # not referenced in the visible part of this script
NUM_VAL = 5000  # not referenced in the visible part of this script
NOISE_DIM = 96  # input size handed to six_Net.generator below
batch_size = 128  # mini-batch size handed to the DataLoader below
def show_images(images):  # plotting helper
    """Draw a batch of images on a square grid of axes.

    Every image is flattened, then reshaped to ``side x side`` where ``side``
    is the ceiling of the square root of its element count, so the count must
    be a perfect square. The grid has ceil(sqrt(number of images)) cells per
    side. Nothing is returned.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))
    img_side = int(np.ceil(np.sqrt(flat.shape[1])))

    plt.figure(figsize=(grid_side, grid_side))
    layout = gridspec.GridSpec(grid_side, grid_side)
    layout.update(wspace=0.05, hspace=0.05)

    for cell_index, pixels in enumerate(flat):
        ax = plt.subplot(layout[cell_index])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))
def preprocess_img(x):
    """Convert ``x`` with ``ToTensor`` and rescale the result from (0, 1) to (-1, 1)."""
    to_tensor = transforms.ToTensor()
    unit_range = to_tensor(x)  # values in (0., 1.)
    return (unit_range - 0.5) / 0.5  # values in (-1., 1.)
def deprocess_img(x):
    """Undo ``preprocess_img`` scaling: map values from (-1, 1) back to (0, 1)."""
    shifted = x + 1.0  # now in (0., 2.)
    return shifted / 2.0  # now in (0., 1.)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST training split under ./data; every image goes through preprocess_img.
# download=False is passed, so this line does not request a download.
train_set = mnist.MNIST('./data', train=True, transform=preprocess_img, download=False)
# NOTE(review): shuffle=False is passed, so the loader is not asked to
# reshuffle samples — confirm this is intended for training.
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=False)
def get_optimizer(net):
    """Return an Adam optimizer over ``net``'s parameters (lr=3e-4, betas=(0.5, 0.999))."""
    adam_settings = {'lr': 3e-4, 'betas': (0.5, 0.999)}
    return torch.optim.Adam(net.parameters(), **adam_settings)
def to_img(x):
    """Convert generator output back into a batch of 1x28x28 images.

    Values are rescaled from (-1, 1) to (0, 1), clamped to that range, and
    viewed as ``(batch, 1, 28, 28)``.
    """
    batch = x.shape[0]
    rescaled = (x + 1.) * 0.5
    clipped = rescaled.clamp(0, 1)  # floor at 0, cap at 1
    return clipped.view(batch, 1, 28, 28)
## Training loop
def train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
              discriminator_loss, generator_loss,
              num_epochs=10, noise_size=96, num_img=6):
    """Alternately train a discriminator and a generator.

    Batches come from the module-level ``train_data`` loader and are moved to
    the module-level ``device``. Every 20 iterations the losses are printed
    and a ``num_img x num_img`` preview grid is redrawn; one image file per
    epoch is written to ``./out``.

    Args:
        D_net: discriminator; receives flattened ``(batch, -1)`` images.
        G_net: generator; receives ``(batch, noise_size)`` noise.
        D_optimizer: optimizer stepping ``D_net``.
        G_optimizer: optimizer stepping ``G_net``.
        discriminator_loss: callable ``(logits_real, logits_fake) -> loss``.
        generator_loss: callable ``(logits_fake) -> loss``.
        num_epochs: number of passes over ``train_data``.
        noise_size: length of each noise vector.
        num_img: side length of the preview grid.
    """
    # squeeze=False keeps ``axes`` two-dimensional even when num_img == 1,
    # so the [row][col] indexing below always works.
    _, axes = plt.subplots(num_img, num_img, figsize=(num_img, num_img),
                           squeeze=False)
    plt.ion()  # interactive mode: the preview refreshes while training runs

    # save_image does not create missing directories.
    os.makedirs('./out', exist_ok=True)

    for epoch in range(num_epochs):
        print()
        fake_images = None
        for iteration, (ima, _) in enumerate(train_data):
            bs = ima.shape[0]

            # ---- Discriminator step ----
            real_data = ima.view(bs, -1).to(device)  # flattened real images
            logits_real = D_net(real_data)

            # Uniform noise rescaled from [0, 1) to [-1, 1).
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5
            fake_images = G_net(sample_noise.to(device))
            # detach(): the discriminator update must not backpropagate into
            # the generator; the non-detached batch is reused below.
            logits_fake = D_net(fake_images.detach())

            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # update the discriminator

            # ---- Generator step ----
            # Score the same fake batch with the freshly updated discriminator.
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            # Report progress and refresh the preview every 20 iterations.
            if iteration % 20 == 0:
                print(f'Epoch: {epoch + 1} | Iter: {iteration} | '
                      f'D_loss: {d_total_error.detach().cpu().numpy()} | '
                      f'G_loss:{g_error.detach().cpu().numpy()}')
                im_gen = deprocess_img(fake_images.detach().cpu().numpy())
                # A short final batch may hold fewer samples than grid cells.
                shown = min(num_img ** 2, len(im_gen))
                for i in range(shown):
                    cell = axes[i // num_img][i % num_img]
                    cell.imshow(np.reshape(im_gen[i], (28, 28)), cmap='gray')
                    cell.set_xticks(())
                    cell.set_yticks(())
                plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
                plt.pause(0.01)

        # NOTE(review): the file name depends only on the epoch, so the
        # snapshot is written once per epoch from the last generated batch —
        # confirm this matches the intended placement.
        if fake_images is not None:
            pic = to_img(fake_images.detach().cpu())
            torchvision.utils.save_image(pic, f'./out/ima_{epoch + 1}.png')
# Build the discriminator first, then the generator, and move both to `device`.
D_net = six_Net.discriminator()
D_net = D_net.to(device)
G_net = six_Net.generator(NOISE_DIM)
G_net = G_net.to(device)

# One optimizer per network, both created by get_optimizer.
D_optimizer, G_optimizer = get_optimizer(D_net), get_optimizer(G_net)

# Loss callables taken from six_Net.
discriminator_loss, generator_loss = (
    six_Net.discriminator_loss,
    six_Net.generator_loss,
)

# 10 epochs, 96-dimensional noise, 5x5 preview grid.
train_gen(
    D_net, G_net,
    D_optimizer, G_optimizer,
    discriminator_loss, generator_loss,
    num_epochs=10, noise_size=96, num_img=5,
)
Resumen:
La red consta de dos pequeñas redes, una para discriminación y otra para generar
Red de discriminación:
Primero se introduce la imagen real en el discriminador y se obtiene su puntuación de «real» (un único valor por imagen)
Genere aleatoriamente un conjunto de datos, envíelo al generador y obtenga datos falsos
Luego envíe los datos falsos al discriminador para obtener la probabilidad falsa
Finalmente, envíe la probabilidad real + probabilidad falsa al cálculo de la función de pérdida
retropropagación
Red generadora:
Datos generados aleatoriamente, enviados al generador para obtener datos falsos
Se envían los datos falsos al discriminador para ver si logra identificarlos, y se obtiene la puntuación correspondiente
Finalmente enviado al cálculo de la función de pérdida
retropropagación
Cabe señalar que, si se generan imágenes de caracteres en color, se debe eliminar la capa BN del discriminador; de lo contrario, la imagen generada no resulta reconocible a simple vista.