Simple record: GAN (Generative Adversarial Network), PyTorch + MNIST

Table of contents

NET

Multilayer perceptron version:

Convolution version

Loss function:

Train:

Summary:

NET

Multilayer perceptron version:

## GAN, multilayer perceptron version
## Discriminator network
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(28 * 28, 256),
            # nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 256),
            # nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 1),
            # nn.Sigmoid()        # would map the score into (0, 1); omitted because
                                  # BCEWithLogitsLoss below expects raw logits
        )

    def forward(self, x):
        x = self.dis(x)
        return x

## Generator network
class generator(nn.Module):
    def __init__(self, in_size = 96):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_size, 1024),
            # by the way, with a hidden size of 256 the results are quite poor
            # nn.BatchNorm1d(1024),
            nn.ReLU(True),

            nn.Linear(1024, 1024),
            # nn.BatchNorm1d(1024),
            nn.ReLU(True),

            nn.Linear(1024, 784),
            nn.Tanh()           # output in (-1, 1), matching the normalized images
        )

    def forward(self, x):
        x = self.gen(x)
        return x
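
As a quick sanity check (my own sketch, not from the original post), push random tensors through both networks and confirm the shapes line up:

import torch

D = discriminator()
G = generator(in_size=96)

z = torch.randn(4, 96)            # batch of 4 noise vectors
fake = G(z)                       # -> (4, 784), values in (-1, 1) from Tanh
score = D(fake)                   # -> (4, 1), raw logits (no Sigmoid)
print(fake.shape, score.shape)    # torch.Size([4, 784]) torch.Size([4, 1])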

Convolution version

### GAN, convolutional (DCGAN-style) version
class DC_discriminator(nn.Module):
    def __init__(self):
        super(DC_discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1),      # (3, 28, 28) -> (32, 24, 24)
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(2, 2),          # -> (32, 12, 12)

            nn.Conv2d(32, 64, 5, 1),     # -> (64, 8, 8)
            nn.LeakyReLU(0.2, True),     # 0.2 slope, matching the MLP version
            nn.MaxPool2d(2, 2)           # -> (64, 4, 4), i.e. 1024 features
        )

        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            # nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.shape[0],-1)
        x = self.fc(x)
        return x

## Generator network
class DC_generator(nn.Module):
    def __init__(self, in_size = 96):
        super(DC_generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_size, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),

            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # (128, 7, 7) -> (64, 14, 14)
            nn.ReLU(True),
            nn.BatchNorm2d(64),

            nn.ConvTranspose2d(64, 3, 4, 2, 1),     # -> (3, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x
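
The layer sizes imply a 3-channel 28x28 input: MNIST itself is 1-channel, so this pair is sized for color images of the same resolution. A minimal shape walk-through, again my own sketch:

import torch

D = DC_discriminator()
G = DC_generator(in_size=96)

z = torch.randn(4, 96)
fake = G(z)        # fc -> (4, 128, 7, 7), deconvs upsample 7 -> 14 -> 28: (4, 3, 28, 28)
score = D(fake)    # convs/pools: 28 -> 24 -> 12 -> 8 -> 4, so 64 * 4 * 4 = 1024 -> (4, 1)
print(fake.shape, score.shape)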

Loss function:

## Loss functions for the adversarial game
## Binary cross-entropy on raw logits (the sigmoid is applied inside the loss)
bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake): # discriminator loss
    true_labels = torch.ones_like(logits_real)    # real images should score 1
    false_labels = torch.zeros_like(logits_fake)  # fake images should score 0
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

## Generator loss
def generator_loss(logits_fake): # generator loss
    true_labels = torch.ones_like(logits_fake)  # the generator wants fakes scored as real
    loss = bce_loss(logits_fake, true_labels)
    return loss


## Least-squares (LSGAN) variants of the same two losses
def ls_discriminator_loss(logits_real, logits_fake):
    loss = 0.5 * ((logits_real - 1) ** 2).mean() + 0.5 * (logits_fake ** 2).mean()
    return loss

def ls_generator_loss(logits_fake):
    loss = 0.5 * ((logits_fake - 1) ** 2).mean()
    return loss
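
To see the two loss styles side by side, here is a tiny numeric example of my own: three fake logits scored by both generator losses. Both reward logits that the discriminator reads as "real" (high):

import torch

logits_fake = torch.tensor([[-2.0], [0.0], [3.0]])
print(generator_loss(logits_fake))      # BCE-with-logits against all-ones labels
print(ls_generator_loss(logits_fake))   # 0.5 * mean((logits - 1)^2)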

Train:

import torch
import torchvision.utils

import six_Net              # the networks and losses defined above
import os

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader

## Hyperparameters
NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

def show_images(images): # plotting helper
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    return

def preprocess_img(x):
    x = transforms.ToTensor()(x)    # x (0., 1.)
    return (x - 0.5) / 0.5          # x (-1., 1.)

def deprocess_img(x):           # x (-1., 1.)
    return (x + 1.0) / 2.0      # x (0., 1.)
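
# Quick self-check (added sketch, not in the original post): deprocess_img
# inverts the preprocess normalization, mapping (-1, 1) back into (0, 1).
_demo = deprocess_img(torch.tensor([-1.0, 0.0, 1.0]))
assert _demo.min() >= 0.0 and _demo.max() <= 1.0    # tensor([0.0, 0.5, 1.0])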

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = mnist.MNIST('./data', train=True, transform=preprocess_img, download=True)
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)   # shuffle for training

def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

def to_img(x):
    '''
    Map the final output back into image form for saving
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)       # clamp into [0, 1]
    x = x.view(x.shape[0], 1, 28, 28)
    return x

## Training loop
def train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             discriminator_loss, generator_loss,
             num_epochs=10, noise_size=96, num_img=6):
    f, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img))
    plt.ion()  # turn interactive mode on so the figure updates during training

    for epoch in range(num_epochs):
        print()
        for iteration, (ima, _) in enumerate(train_data):
            bs = ima.shape[0]
            ## Discriminator step
            real_data = ima.view(bs, -1).to(device)    # real images, flattened
            logits_real = D_net(real_data)             # discriminator score on real data

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in (-1, 1)
            g_fake_seed = sample_noise.to(device)
            fake_images = G_net(g_fake_seed)           # generated fake data
            logits_fake = D_net(fake_images.detach())  # detach: this step must not backprop into G

            ## Discriminator backward pass
            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                  # update the discriminator

            ## Generator step
            fake_images = G_net(g_fake_seed)    # regenerate fakes from the same noise

            gen_logits_fake = D_net(fake_images)      # can the discriminator spot them?
            g_error = generator_loss(gen_logits_fake) # generator loss

            ## Generator backward pass
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            ## Plot the generated images every 20 iterations
            if iteration % 20 == 0:
                print(f'Epoch: {epoch + 1} | Iter: {iteration} | '
                      f'D_loss: {d_total_error.item():.4f} | '
                      f'G_loss: {g_error.item():.4f}')
                im_gen = deprocess_img(fake_images.data.cpu().numpy())
                for i in range(num_img ** 2):
                    a[i // num_img][i % num_img].imshow(np.reshape(im_gen[i], (28, 28)), cmap='gray')
                    a[i // num_img][i % num_img].set_xticks(())
                    a[i // num_img][i % num_img].set_yticks(())
                plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
                plt.pause(0.01)
        os.makedirs('./out', exist_ok=True)     # make sure the output directory exists
        pic = to_img(fake_images.cpu().data)
        torchvision.utils.save_image(pic, f'./out/ima_{epoch + 1}.png')


D_net = six_Net.discriminator().to(device)
G_net = six_Net.generator(NOISE_DIM).to(device)
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)
discriminator_loss = six_Net.discriminator_loss
generator_loss = six_Net.generator_loss

train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             discriminator_loss, generator_loss,
          10, 96, 5)
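
Since six_Net also defines the least-squares losses, swapping them in is a one-line change; everything else stays the same (a sketch, assuming the module layout above):

train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             six_Net.ls_discriminator_loss, six_Net.ls_generator_loss,
          10, 96, 5)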

Summary:

The GAN consists of two small networks: a discriminator and a generator.

Discriminator step:

Feed a batch of real images into the discriminator to get one real-score logit per image.

Sample random noise and feed it into the generator to get fake images.

Feed the fake images into the discriminator to get fake-score logits.

Feed the real and fake logits into the loss function,

then backpropagate.

Generator step:

Sample random noise and feed it into the generator to get fake images.

Feed the fake images into the discriminator to see whether it can spot them; this gives fake-score logits.

Feed these logits into the generator loss,

then backpropagate.

It is worth noting that when generating three-channel color pictures of people, the BN layers in the discriminator must be removed; otherwise the generated pictures are completely unrecognizable to the human eye.

Origin blog.csdn.net/qq_42792802/article/details/126164883