基于MNIST的GANs实现【Pytorch】

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/a19990412/article/details/83956498

简述

其实是根据我之前写的两个代码改的。(之前已经有过非常详细的解释了,可以去看看)

同时,这也是结合了我之前写的DCGANs之后实现的一份代码。

在MNIST上只筛选出特定数字的做法,是根据下面的这篇文章得到的。

之前的代码上都有非常详细的解释。这里只是在上面的基础上做了一点点改进,所以就不再给出特别详细的解释,但是代码中仍然保留了注释部分。

图形演变过程

在这里插入图片描述

代码

import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio
# Scratch directory for the per-step PNG frames. It is recreated empty on
# every run so frames left over from a previous run never end up in the GIF.
PNGFILE = './png/'
if os.path.exists(PNGFILE):
    shutil.rmtree(PNGFILE)
os.mkdir(PNGFILE)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for generator
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 100  # think of this as number of ideas for generating an art work (Generator)
target_num = 0  # the single digit class the GAN is trained to draw
EPOCH = 10  # number of full passes over the training set
# Passed as `download=` to the dataset. True fetches MNIST only when it is
# missing under the dataset root and skips the download when the files are
# already there; False makes a first run on a clean machine fail.
DOWNLOAD_MNIST = True
ART_COMPONENTS = 28 * 28  # number of pixels in one flattened 28x28 image


# MNIST handwritten digits, optionally restricted to one digit class.

class myMNIST(torchvision.datasets.MNIST):
    """MNIST dataset that can be filtered down to a single digit.

    Accepts the same arguments as ``torchvision.datasets.MNIST`` plus
    ``targetNum``.

    Args:
        root: Directory the dataset is stored in / loaded from.
        train: Select the training split (True) or the test split (False).
        transform: Optional transform applied to each image.
        target_transform: Optional transform applied to each label.
        download: Forwarded to the parent class unchanged.
        targetNum: When not None, keep only the samples whose label equals
            this value, then drop trailing samples so the number kept is a
            whole multiple of the module-level ``BATCH_SIZE``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, targetNum=None):
        super(myMNIST, self).__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download)
        if targetNum is not None:
            # `data` / `targets` are the storage attributes of current
            # torchvision; `train_data` / `train_labels` are read-only
            # deprecated aliases there and cannot be assigned to.
            keep = self.targets == targetNum
            # Round the kept count down to a whole number of batches.
            usable = int(keep.sum()) // BATCH_SIZE * BATCH_SIZE
            self.data = self.data[keep][:usable]
            self.targets = self.targets[keep][:usable]

    def __len__(self):
        """Return the number of samples currently held (after any filtering)."""
        return len(self.data)


# Training split restricted to the digit `target_num`. ToTensor converts each
# PIL image / numpy array into a float tensor of shape (C x H x W) scaled to
# the [0.0, 1.0] range.
train_data = myMNIST(
    './mnist/',  # where the dataset is stored / loaded from
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    targetNum=target_num,
)
print(len(train_data))

# Serve the training set in shuffled mini-batches of BATCH_SIZE images,
# each image being 28x28.
train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Generator: turns a vector of N_IDEAS random "ideas" into a flattened
# painting of ART_COMPONENTS values.
_generator_layers = [
    nn.Linear(N_IDEAS, 128),
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),
    nn.ReLU(),
]
G = nn.Sequential(*_generator_layers)

# Discriminator: receives a flattened painting (from the real data set or
# from G) and outputs the probability that it is a real one.
_discriminator_layers = [
    nn.Linear(ART_COMPONENTS, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),
]
D = nn.Sequential(*_discriminator_layers)

# One Adam optimizer per network so each can be stepped independently.
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)

# Constant "real" (all ones) and "fake" (all zeros) label vectors, one entry
# per sample in a batch.
label_Real = torch.ones(BATCH_SIZE)
label_Fake = torch.zeros(BATCH_SIZE)

# Paths of the PNG frames saved during training, in the order they were written.
filePath = []

# Alternating GAN training: each step updates D once, then G once, and every
# 20th step saves the first generated sample of the batch as a PNG frame.
for epoch in range(EPOCH):
    for step, (images, imagesLabel) in enumerate(train_loader):
        # Size everything from the batch actually received rather than from
        # BATCH_SIZE, so a shorter batch cannot break the reshape.
        n_samples = images.size(0)
        G_ideas = torch.randn((n_samples, N_IDEAS))
        G_paintings = G(G_ideas)
        images = images.reshape(n_samples, -1)  # flatten each image to one row

        # ---- Discriminator update ----
        # The fake batch is detached so this backward pass stays inside D and
        # does not compute gradients for G.
        prob_artist0 = D(images)  # D tries to increase this prob
        prob_artist1 = D(G_paintings.detach())  # D tries to decrease this prob
        D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

        optimD.zero_grad()
        D_loss.backward()
        optimD.step()

        # ---- Generator update ----
        # Re-score the fakes with the just-updated D. Back-propagating through
        # the earlier forward pass would use weights that optimD.step() has
        # since modified in place, which autograd rejects.
        prob_artist1 = D(G_paintings)
        G_loss = torch.mean(torch.log(1. - prob_artist1))

        optimG.zero_grad()
        G_loss.backward()
        optimG.step()

        if step % 20 == 0:
            plt.cla()
            picture = G_paintings[0].detach().numpy().reshape((28, 28))
            plt.imshow(picture, cmap=plt.cm.gray_r)
            png_path = PNGFILE + '%d-%d.png' % (epoch, step)
            plt.savefig(png_path)
            filePath.append(png_path)

# Load every saved frame in the order it was written, discard the scratch
# directory, then stitch the frames into a single animated GIF.
generated_images = [imageio.imread(frame_path) for frame_path in filePath]
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)

猜你喜欢

转载自blog.csdn.net/a19990412/article/details/83956498
今日推荐