Simple notes: GAN (generative adversarial network), PyTorch + MNIST

Table of contents

Network

Multilayer perceptron version

Convolutional version

Loss function

Training

Summary


Network

Multilayer perceptron version:

## GAN network, multilayer perceptron version
## Discriminator network
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(28 * 28, 256),
            # nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 256),
            # nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 1),
            # nn.Sigmoid()        # would squash the output to 0~1; not needed because BCEWithLogitsLoss applies the sigmoid itself
        )

    def forward(self, x):
        x = self.dis(x)
        return x

## Generator network
class generator(nn.Module):
    def __init__(self, in_size = 96):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_size, 1024),
            # by the way, with a hidden size of 256 the results are quite poor
            # nn.BatchNorm1d(256),
            nn.ReLU(True),

            nn.Linear(1024, 1024),
            # nn.BatchNorm1d(256),
            nn.ReLU(True),

            nn.Linear(1024, 784),
            nn.Tanh()           ## output lies between -1 and 1
        )

    def forward(self, x):
        x = self.gen(x)
        return x
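A quick way to confirm the two networks fit together is to push a small batch of noise through them and look at the shapes. This is a minimal sketch, not part of the original script, assuming the two classes above are in scope:

import torch

# hypothetical shape check on CPU, batch of 4
z = torch.rand(4, 96) * 2 - 1       # noise in (-1, 1), the same range used during training
G = generator(96)
D = discriminator()
fake = G(z)                         # (4, 784), values in (-1, 1) because of the Tanh
score = D(fake)                     # (4, 1), raw logits (no Sigmoid; see the loss section)
print(fake.shape, score.shape)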

Convolutional version

### GAN, convolutional version
## Discriminator network
class DC_discriminator(nn.Module):
    def __init__(self):
        super(DC_discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1),      # (bs, 3, 28, 28) -> (bs, 32, 24, 24)
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(2, 2),          # -> (bs, 32, 12, 12)

            nn.Conv2d(32, 64, 5, 1),     # -> (bs, 64, 8, 8)
            nn.LeakyReLU(0.02, True),
            nn.MaxPool2d(2, 2)           # -> (bs, 64, 4, 4), flattened to 1024
        )

        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            # nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.02, True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

## Generator network
class DC_generator(nn.Module):
    def __init__(self, in_size = 96):
        super(DC_generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_size, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),

            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),           # (bs, 128, 7, 7) -> (bs, 64, 14, 14)
            nn.ReLU(True),
            nn.BatchNorm2d(64),

            nn.ConvTranspose2d(64, 3, 4, 2, padding=1),     # -> (bs, 3, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x
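The convolutional pair can be checked the same way. The 1024-unit fully connected layer and the two stride-2 upsampling steps imply 3-channel 28x28 inputs, which is what this illustrative sketch assumes (batch size greater than 1 so the BatchNorm1d layers have something to normalize over):

import torch

# hypothetical shape check for the convolutional pair
z = torch.rand(4, 96) * 2 - 1
G = DC_generator(96)
D = DC_discriminator()
fake = G(z)                         # (4, 3, 28, 28)
score = D(fake)                     # (4, 1)
print(fake.shape, score.shape)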

Loss function:

## Loss functions for the adversarial networks
## Standard GAN loss: binary cross-entropy on the logits
bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake): # discriminator loss
    size = logits_real.shape[0]
    true_labels = torch.autograd.Variable(torch.ones(size, 1)).float().cuda()
    false_labels = torch.autograd.Variable(torch.zeros(size, 1)).float().cuda()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

## Generator loss
def generator_loss(logits_fake): # generator loss
    size = logits_fake.shape[0]
    true_labels = torch.autograd.Variable(torch.ones(size, 1)).float().cuda()
    loss = bce_loss(logits_fake, true_labels)
    return loss


## Alternative GAN loss functions
## These use least squares (LSGAN)
def ls_discriminator_loss(logits_real, logits_fake):
    loss = 0.5 * ((logits_real - 1) ** 2).mean() + 0.5 * (logits_fake ** 2).mean()
    return loss

def ls_generator_loss(logits_fake):
    loss = 0.5 * ((logits_fake - 1) ** 2).mean()
    return loss
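The label convention is the usual one: real images are pushed toward 1, generated ones toward 0, and the generator tries to make its outputs score as 1. A tiny illustrative check of the least-squares pair on dummy logits (the BCE pair would need a CUDA device as written, because of the .cuda() calls):

import torch

# hypothetical check: a perfectly confident discriminator gives zero LS loss
logits_real = torch.ones(4, 1)      # discriminator says "real" for real images
logits_fake = torch.zeros(4, 1)     # discriminator says "fake" for generated images
print(ls_discriminator_loss(logits_real, logits_fake))   # tensor(0.)
print(ls_generator_loss(logits_fake))                    # tensor(0.5000); the generator wants these logits at 1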

Training:

import torch
import torchvision.utils

import six_Net          # module containing the networks and loss functions defined above
import torch.nn as nn
import tqdm
import os

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader, sampler
from torch import optim
from torchvision.utils import save_image

## Settings
NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

def show_images(images): # plotting helper: lay images out on a square grid
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    return

def preprocess_img(x):
    x = transforms.ToTensor()(x)    # x (0., 1.)
    return (x - 0.5) / 0.5          # x (-1., 1.)

def deprocess_img(x):           # x (-1., 1.)
    return (x + 1.0) / 2.0      # x (0., 1.)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = mnist.MNIST('./data', train=True, transform=preprocess_img, download=False)  # assumes MNIST is already in ./data
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=False)

def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

def to_img(x):
    '''
    Convert the final output back into image form
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)       # clamp to the range [0, 1]
    x = x.view(x.shape[0], 1, 28, 28)
    return x

## Training loop
def train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             discriminator_loss, generator_loss,
             num_epochs=10, noise_size=96, num_img=6):
    f, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img))
    plt.ion()  # Turn the interactive mode on, continuously plot

    for epoch in range(num_epochs):
        print()
        for iteration, (ima, _) in enumerate(train_data):
            bs = ima.shape[0]
            ## Discriminator step
            real_data = torch.autograd.Variable(ima).view(bs, -1).to(device)  # real data
            logits_real = D_net(real_data)      # discriminator scores on real data

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in (-1, 1)
            g_fake_seed = torch.autograd.Variable(sample_noise).to(device)  # noise input
            fake_images = G_net(g_fake_seed)    # generated (fake) data
            logits_fake = D_net(fake_images)    # discriminator scores on fake data

            ## Discriminator backward pass
            d_total_error = discriminator_loss(logits_real, logits_fake)  # discriminator loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                  # update the discriminator

            ## Generator step
            g_fake_seed = torch.autograd.Variable(sample_noise).to(device)  # same noise
            fake_images = G_net(g_fake_seed)    # generated (fake) data

            gen_logits_fake = D_net(fake_images)    ## feed the fakes to the discriminator: can it tell them apart?
            g_error = generator_loss(gen_logits_fake)  # generator loss

            ## Generator backward pass
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            ## Plot the generated images every 20 iterations
            if iteration % 20 == 0:
                print(f'Epoch: {epoch + 1} | Iter: {iteration} | '
                      f'D_loss: {d_total_error.cpu().data.numpy()} | '
                      f'G_loss:{g_error.cpu().data.numpy()}')
                im_gen = deprocess_img(fake_images.data.cpu().numpy())
                for i in range(num_img ** 2):
                    a[i // num_img][i % num_img].imshow(np.reshape(im_gen[i], (28, 28)), cmap='gray')
                    a[i // num_img][i % num_img].set_xticks(())
                    a[i // num_img][i % num_img].set_yticks(())
                plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
                plt.pause(0.01)
        pic = to_img(fake_images.cpu().data)
        torchvision.utils.save_image(pic, f'./out/ima_{epoch + 1}.png')


D_net = six_Net.discriminator().to(device)
G_net = six_Net.generator(NOISE_DIM).to(device)
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)
discriminator_loss = six_Net.discriminator_loss
generator_loss = six_Net.generator_loss

train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             discriminator_loss, generator_loss,
          10, 96, 5)
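Because train_gen takes the loss functions as arguments, the least-squares variants defined earlier drop in without any other change. A sketch of that swap, assuming ls_discriminator_loss and ls_generator_loss also live in six_Net next to the BCE versions:

# hypothetical LSGAN run: same architectures and optimizers, only the losses change
D_net_ls = six_Net.discriminator().to(device)
G_net_ls = six_Net.generator(NOISE_DIM).to(device)
train_gen(D_net_ls, G_net_ls,
          get_optimizer(D_net_ls), get_optimizer(G_net_ls),
          six_Net.ls_discriminator_loss, six_Net.ls_generator_loss,
          10, NOISE_DIM, 5)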

Summary:

The GAN consists of two small networks, one for discrimination and one for generation; the two update steps are outlined below and then condensed into a short sketch.

Discriminator network:

First feed a real image into the discriminator and get a 1-dimensional "real" score.

Randomly generate a noise vector and feed it to the generator to get fake data.

Then feed the fake data to the discriminator to get the "fake" score.

Finally pass the real score and the fake score to the loss function.

Backpropagate.

Generator network:

Randomly generate noise and feed it to the generator to get fake data.

Feed the fake data to the discriminator to see whether it can be told apart; this gives a score.

Finally pass that score to the loss function.

Backpropagate.
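Condensed into code, one iteration of the two updates looks like this. This is only an illustrative recap of the train_gen loop above, assuming the objects from the training script are in scope and a CUDA device is available (the BCE losses call .cuda()):

# one GAN iteration, condensed (illustrative recap of train_gen)
for ima, _ in train_data:
    bs = ima.shape[0]
    # discriminator update: score real and generated data, then backpropagate
    logits_real = D_net(ima.view(bs, -1).to(device))
    noise = (torch.rand(bs, NOISE_DIM, device=device) - 0.5) / 0.5
    logits_fake = D_net(G_net(noise))
    d_loss = discriminator_loss(logits_real, logits_fake)
    D_optimizer.zero_grad()
    d_loss.backward()
    D_optimizer.step()
    # generator update: regenerate fakes, score them, backpropagate through G
    g_loss = generator_loss(D_net(G_net(noise)))
    G_optimizer.zero_grad()
    g_loss.backward()
    G_optimizer.step()
    break   # just one iteration, for illustration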

It is worth noting that if the network is made to generate three-channel color character images, the BN layers in the discriminator should be removed; otherwise the generated images are unrecognizable to the human eye.

Source: blog.csdn.net/qq_42792802/article/details/126164883