生成的対立ネットワーク(GAN)の詳細な説明と例

GANの紹介

GAN を理解する直感的な方法は、ゲーム理論の観点から理解することです。GAN は、ジェネレーターとディスクリミネーターの 2 人のプレーヤーで構成され、それぞれが相手を打ち負かそうとします。ジェネレーターはパーティションからランダム ノイズを取得し、そこから出力のような分布を生成しようとします。ジェネレーターは常に、真の分布と見分けがつかない分布を作成しようとします。つまり、偽の出力は実際の画像のように見えます。ただし、明示的なトレーニングやラベル付けがなければ、ジェネレーターは実画像を区別できず、その唯一のソースはランダムな浮動小数点数のテンソルです。

その後、GAN はゲーム内の別のプレイヤー、ディスクリミネーターを紹介します。ディスクリミネーターは、生成する出力が実画像のように見えないことをジェネレーターに通知することのみを担当するため、ジェネレーターは画像の生成方法を変更して、それが実画像であることをディスクリミネーターに納得させます。しかし、ディスクリミネーターは、画像がジェネレーターから生成されたことを認識しているため、生成した画像が本物ではないことを常にジェネレーターに伝えることができます。この問題を解決するために、GAN は実際の画像をゲームに導入し、ディスクリミネーターをジェネレーターから分離します。ここで、ディスクリミネーターは一連の本物の画像とジェネレーターからの偽の画像から画像を取得し、各画像がどこから来たのかを把握する必要があります。

最初は、識別器は何も知りませんが、ランダムに結果を予測します。ただし、弁別器のタスクは分類タスクに変更できます。弁別器は、入力画像を元の画像または生成された画像として分類できます。これはバイナリ分類です。同様に、画像を正しく分類するように弁別器ネットワークをトレーニングし、最終的に逆伝播を通じて、弁別器は実際の画像と生成された画像を区別することを学習します。

ここに画像の説明を挿入

コード例

データセットの紹介:
この実験では、花のデータセットを選択して画像を生成します.このデータセットには 6 つのカテゴリがあります.
ここに画像の説明を挿入

モデルのトレーニング
トレーニング ディスクリミネーター:
本物の写真の場合、出力はできるだけ 1 にする必要があります
偽の写真の場合、出力はできるだけ 0 にする必要があります
トレーニング ジェネレーター:
偽の写真の場合、出力はできるだけ 1 にする必要があります
1.ジェネレーターをトレーニングする場合、ディスクリミネーターのパラメーターを調整する必要はありません; ディスクリミネーターをトレーニングする場合、ジェネレーターのパラメーターを調整する必要はありません。
2. ディスクリミネーターをトレーニングするときは、デタッチ操作を使用して、ジェネレーターによって生成された画像の計算グラフを切り捨て、バックプロパゲーションが勾配をジェネレーターに渡さないようにする必要があります。ディスクリミネーターをトレーニングするときにジェネレーターをトレーニングする必要がないため、ジェネレーターの勾配は必要ありません。
3. ディスクリミネータをトレーニングする場合、実際の画像を 1 と判断するために 1 回、偽の画像を 0 と判断するために 1 回、合計 2 回のバックプロパゲーションが必要です。2 つのデータをまとめて、1 つの順伝播と 1 つの逆伝播を実行することも可能です。
4. フェイク画像の場合、ディスクリミネーターをトレーニングするときは 0 を出力する必要があり、ジェネレーターをトレーニングするときは 1 を出力する必要があります。 =基準(出力、true_labels)。ディスクリミネーターは偽の画像を偽のラベルとして識別できることを望み、ジェネレーターはそれを真のラベルとして識別できることを望んでおり、ディスクリミネーターとジェネレーターはお互いに改善するために戦います。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from dataloader import MyDataset
from model import Generator, Discriminator
import torchvision
import numpy as np
import matplotlib.pyplot as plt
if __name__ == '__main__':
    LR = 0.0002
    EPOCH = 1000  # 50
    BATCH_SIZE = 40
    N_IDEAS = 100
    EPS = 1e-10
    TRAINED = False
    #path = r'./data/image'
    train_data = MyDataset(path=path, resize=96, Len=10000, img_type='jpg')
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()

    if TRAINED:
        G = torch.load('G.pkl').cuda()
        D = torch.load('D.pkl').cuda()
    else:
        G = Generator(N_IDEAS).cuda()
        D = Discriminator(3).cuda()

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)

    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, x in enumerate(train_loader):
            x = x.cuda()
            rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()
            G_imgs = G(rand_noise)

            D_fake_probs = D(G_imgs)
            D_real_probs = D(x)

            p_d_fake = torch.squeeze(D_fake_probs)
            p_d_real = torch.squeeze(D_real_probs)

            D_loss = -torch.mean(torch.log(p_d_real + EPS) + torch.log(1. - p_d_fake + EPS))
            G_loss = -torch.mean(torch.log(p_d_fake + EPS))
            # D_loss = -torch.mean(torch.log(D_real_probs) + torch.log(1. - D_fake_probs))
            # G_loss = torch.mean(torch.log(1. - D_fake_probs))

            optimizerD.zero_grad()
            D_loss.backward(retain_graph=True)
            optimizerD.step()

            optimizerG.zero_grad()
            G_loss.backward(retain_graph=True)
            optimizerG.step()

            tmpD_ = D_loss.cpu().detach().data
            tmpG_ = G_loss.cpu().detach().data
            tmpD += tmpD_
            tmpG += tmpG_
        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)
        )
        # if (epoch+1) % 5 == 0:
        select_epoch = [1, 5, 10, 20, 50, 80, 100, 150, 200, 400, 500, 800, 999, 1500, 2000, 3000, 4000, 5000, 6000, 8000, 9999]
        if epoch in select_epoch:
plt.imshow(np.squeeze(G_imgs[0].cpu().detach().numpy().transpose((1, 2, 0))) * 0.5 + 0.5)
            plt.savefig('./result1/_%d.png' % epoch)

    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')

以下は、トレーニングを複数回行った場合の結果です。
ここに画像の説明を挿入
ここに画像の説明を挿入
ここに画像の説明を挿入
ここに画像の説明を挿入
ここに画像の説明を挿入
ここに画像の説明を挿入
完全なコードは次のとおりです。

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223


noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generator


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator,self).__init__()
        self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()       # 3 * 96 * 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # 输出一个概率
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96
        real_image = Variable(image)
        real_image = real_image.cuda()

        if (i + 1) % d_every == 0:
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # 尽可能把真图片判为True
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()       # 根据噪声生成假图
            fake_output = Discriminator(fake_img)       # 尽可能把假图片判为False
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()

        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)        # 这里没有detach
            fake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为True
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_imags = Generator(fix_noises)
    fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5

    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)

    i = 1
    for image in fix_fake_imags:
        ax = fig.add_subplot(8, 8, eval('%d' % i))
        # plt.xticks([]), plt.yticks([])  # 去除坐标轴
        plt.axis('off')
        plt.imshow(image.permute(1, 2, 0))
        i += 1
    plt.subplots_adjust(left=None,  # the left side of the subplots of the figure
                        right=None,  # the right side of the subplots of the figure
                        bottom=None,  # the bottom of the subplots of the figure
                        top=None,  # the top of the subplots of the figure
                        wspace=0.05,  # the amount of width reserved for blank space between subplots
                        hspace=0.05)  # the amount of height reserved for white space between subplots)
    plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)
    plt.savefig("images/%dcgan.png" % num)


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecated
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数
    ])

    dataset = tv.datasets.ImageFolder(dir, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'

    print('数据加载完毕!')
    Generator = NetGenerator()
    Discriminator = NetDiscriminator()

    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    true_labels = Variable(torch.ones(batch_size))     # batch_size
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布

    if torch.cuda.is_available() == True:
        print('Cuda is available!')
        Generator.cuda()
        Discriminator.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()


    plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]

    for i in range(3000):        # 最大迭代次数
        train()
        print('迭代次数:{}'.format(i))
        if i in plot_epoch:
            show(i)


おすすめ

転載: blog.csdn.net/weixin_45807161/article/details/123776427