GAN はアニメのアバターを生成します

GAN はアニメのアバターを生成します

この記事は主に Li Honyi 氏の GAN コースに基づいており、「深層学習フレームワーク pytorch の紹介と実践」のコードをいくつか組み合わせています。

GAN 原理の概要

GAN (敵対的生成ネットワーク) は対立ネットワークを生成します GAN は有名な問題を解決します:サンプルのバッチが与えられた場合、システムをトレーニングすると同様の新しいサンプルを生成できます生成的対立ネットワークは、その名前が示すように、生成者 (Generator) と弁別者 (Discriminator) の 2 つの部分から構成されます。

  • ジェネレーター: ランダムなノイズを入力して画像を生成します。
  • Discriminator(ディスクリミネーター):画像が本物か偽物かを判別する

画像の説明を追加してください

基本的な学習プロセスは図に示されています. 第一世代の生成器は最初にランダムに初期化され, ランダムなノイズが通過し, 生成された画像はぼやけています. 第一世代の識別器が行うことは, 識別できることです.画像が最初の画像であるか、その世代のジェネレーターによって生成された画像であるか、実際の画像であるか。おそらく第一世代の識別器はその色が本物の絵だと思っているのでしょうが、生成器は識別器を騙さなければならないので進化しなければならず、カラーの絵を生成するために第二世代に進化しました。識別器もそれに応じて進化し、口のあるものが本物であることを発見し、生成された画像と本物の画像を区別します。そしてジェネレータも第3世代に進化し、口の画像を生成して第2世代の識別器を騙し、識別器も第3世代に進化しました。お互いに段階的に戦い、進化し、最終的に二次元アバターを生成します。

アルゴリズム処理

ディスクリミネーターをトレーニングする

まず、ジェネレーターとディスクリミネーターを初期化し、各トレーニング反復で、まずジェネレーターを修正し、ジェネレーターにランダム ノイズを渡し、対応する画像を生成します。前述したように、ディスクリミネーターの役割は、本物の画像と偽の画像を区別することです。ディスクリミネーターは主に画像にスコアを付けます。スコアが高いほど、本物の画像である可能性が高く、スコアが低いほど、本物の画像である可能性が高くなります。偽の写真。したがって、ディスクリミネーターを学習させる際には、データベースからのグラフとジェネレーターからのグラフを受け取り、パラメーターを調整し、データベースからのグラフであれば高スコアを与え、データベースからのグラフであれば高いスコアを与えます。ジェネレーターでは、低いスコアが得られます。言い換えれば、私たちのトレーニング場所での識別子は、データセットによって与えられたグラフが 1 に近いほど優れており、ジェネレーターによって与えられたグラフが 0 に近いほど優れているということです。

画像の説明を追加してください

トレーニングジェネレーター

ディスクリミネーターは以前にトレーニングされており、今回ジェネレーターがトレーニングされました。ジェネレーターが行う必要があるのは、ランダム ノイズを受信し、画像を生成し、ディスクリミネーターを「騙す」ことです。ジェネレーターを「騙す」方法。実際のアプローチは、生成された画像をディスクリミネーターに送信してスコアリングを行うことであり、目標はスコアをできるだけ高くすることです (1 に近いほど良い)。パート全体は、ノイズを入力してスコアを生成する Generator と Discriminator を含む巨大な隠れ層である全体とみなすことができます。スコアを 1 に近づけるようにパラメータを調整しますが、Discriminator の中央部分は調整できません。

画像の説明を追加してください

プロセス全体

画像の説明を追加してください

  1. データセットから m 個のサンプル {x1,x2...xm} を選択し、m 個のノイズを作成します。z の次元は自分で決定します。これは後でジェネレーターによって入力されるノイズです。生成されたデータを取得します。これが G です。 (z) ジェネレータ 生成; 次の式が最大になるように識別器パラメータを更新します。次の式は次のことを意味します: D(x) 実画像を識別するための識別器のスコアを対数平均に 1 - 値の対数平均を加えた値偽の画像値を区別するための識別器の値。簡単に言うと、本物の写真を区別するスコアは高いほど良く、偽の写真を区別する場合は、偽の写真のスコアが値 1 から離れているほど良いことになります。( bceloss を使用したバイナリ分類器のトレーニングと同等)
  2. m 個のノイズ ポイント z、ジェネレーターを更新すると、ジェネレーターは z によって生成された画像をディスクリミネーターにフィードします。スコアが高いほど優れています。

紀元前の喪失

bce loss分类,用于二分类问题。数学公式如下
l o s s ( X i , y i ) = − w i [ y i l o g x i + ( 1 − y i ) l o g ( 1 − x i ) ] loss(X_i,y_i) = -w_i[y_ilogx_i + (1 - y_i)log(1 - xi)] ロス( X _ _ _私はy私は)=−w_私は[ y私はl o g x私は+( 1y私は) l o g ( 1x i ) ]
pytorch中bceloss

class torch.nn.BCELoss(weight: Optional[torch.Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean')

重み: 重み行列を初期化します。

size_average: デフォルトは True、損失を平均化します。

削減: デフォルトの合計、batch_size の平均損失

コード

前の理論を理解した後は、ここで使用するデータセットがExtra Dataであることを理解し始めることができます。また、 Anime Datasetを試してみることもできますはしごは自分で見つけてください。

初期化

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.datasets as DataSet
import torchvision.transforms as transform
import torch.utils.data as Data
import numpy as np
import torch.nn as nn
import torch.optim as optim
import os


# 用于图片保存
def saveImg(inp, name):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.savefig(name)

    
# 用于图片显示,可以调试数据集是否加载成功
def imgshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()


    
batch_size = 20
# 图像处理,尺寸转为64 * 64,转tensor范围(0,1), Normalize之后转为 (-1, 1)
simple_transform = transform.Compose([
    transform.Resize((64, 64)),
    transform.ToTensor(),
    transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 使用GPU or CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 噪点数量
noise_z = 100
generator_feature_map = 64
# 加载数据集
path = "AnimeDataset"
train_set = DataSet.ImageFolder(path, simple_transform)
train_loader = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
# 正确分数标签
true_label = torch.ones(batch_size).to(device)
true_label = true_label.view(-1, 1)
# 错误分数标签
false_label = torch.zeros(batch_size).to(device)
false_label = false_label.view(-1, 1)
# 固定的noises,这样在每个Epoch完成之后可以看到generator产生同个照片的过程
fix_noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
# 随机noises
noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
g_train_cycle = 1  # 训练生成器周期
save_img_cycle = 1  # 每几次epoch输出一次结果
print_step = 200  # 打印loss 信息周期

bceloss = nn.BCELoss()

ビルダー

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            # 100*1*1 --> (64 * 8) * 4 *4
            nn.ConvTranspose2d(noise_z, generator_feature_map * 8, kernel_size=4, bias=False),
            nn.BatchNorm2d(generator_feature_map * 8),
            nn.ReLU(True))
        self.layer2 = nn.Sequential(
            # (64 * 8) * 4 * 4 --> (64 * 4)*8*8
            nn.ConvTranspose2d(generator_feature_map * 8, generator_feature_map * 4, kernel_size=4, stride=2,
                               padding=1),
            nn.BatchNorm2d(generator_feature_map * 4),
            nn.ReLU(True))
        self.layer3 = nn.Sequential(

            # (64*4)*8*8 --> (64*2)*16*16
            nn.ConvTranspose2d(generator_feature_map * 4, generator_feature_map * 2, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(generator_feature_map * 2),
            nn.ReLU(True))
        self.layer4 = nn.Sequential(

            # (64*2)*16*16 --> 64*32*32
            nn.ConvTranspose2d(generator_feature_map * 2, generator_feature_map, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(generator_feature_map),
            nn.ReLU(True))
        self.layer5 = nn.Sequential(
            # 64*32*32 --> 3*64*64
            nn.ConvTranspose2d(generator_feature_map, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        return out

識別子

class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()
        # layer1 输入 3 x 96 x 96, 输出 (ndf) x 32 x 32
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer2 输出 (ndf*2) x 16 x 16
        self.layer2 = nn.Sequential(
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer3 输出 (ndf*4) x 8 x 8
        self.layer3 = nn.Sequential(
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer4 输出 (ndf*8) x 4 x 4
        self.layer4 = nn.Sequential(
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer5 输出一个数(概率)
        self.layer5 = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    # 定义NetD的前向传播
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.view(-1,1)
        return out

オプティマイザ

generator = Generator().to(device)
discriminator = Discriminator().to(device)

learning_rate = 0.0002
beta = 0.5
# 优化器初始化
g_optim = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta, 0.999))
d_optim = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta, 0.999))

損失関数

def loss_g_func(testLabel, trueLabel):
    return bceloss(testLabel, trueLabel)


def loss_d_func(real_predicts, real_labels, fake_predicts, fake_labels):
    real = bceloss(real_predicts, real_labels)	# 真图片分数,不断靠近1
    fake = bceloss(fake_predicts, fake_labels)	# 假图片分数,不断靠近0
    real.backward()
    fake.backward()
    return real + fake

トレーニングを開始する

# 训练Discriminator
train_num = 25
for trainIdx in range(train_num):
    for step, data in enumerate(train_loader):
        image_x, _ = data
        image_x = image_x.to(device)
        # 训练判别器
        noises.data.copy_(torch.randn(batch_size, noise_z, 1, 1))
        out = discriminator(image_x)	# 原图产生的分数
        fake_pic = generator(noises)	# 生成器生成图像
        fake_predict = discriminator(fake_pic.detach())	# 使用detach()切断求导关联
        d_optim.zero_grad()
        dloss = loss_d_func(out, true_label, fake_predict, false_label)
        d_optim.step()

        if step % g_train_cycle == 0:
            # 训练生成器
            g_optim.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noise_z, 1, 1))
            fake_img = generator(noises)
            fake_out = discriminator(fake_img)
            # 尽可能让判别器把假图判别为1
            loss_fake = loss_g_func(fake_out, true_label)
            loss_fake.backward()
            g_optim.step()

        if step % print_step == print_step - 1:
            print("train: ", trainIdx, "step: ", step + 1, " d_loss: ", dloss.item(), "mean score: ",
                  torch.mean(out).item())
            print("train: ", trainIdx, "step: ", step + 1, " g_loss: ", loss_fake.item(), "mean score: ",
                  torch.mean(fake_out).item())

    if trainIdx % save_img_cycle == 0:
        fix_fake_image = generator(fix_noises)
        fix_fake_image = fix_fake_image.data.cpu()
        comb_img = torchvision.utils.make_grid(fix_fake_image, nrow=4)
        savepath = os.path.join("gan", "pics", "g_%s.jpg" % trainIdx)
        saveImg(comb_img, savepath)
        torch.save(discriminator.state_dict(), './gan/netd_%s.pth' % trainIdx)
        torch.save(generator.state_dict(), './gan/netg_%s.pth' % trainIdx)

結果

これは1エポックの効果です

画像の説明を追加してください

5 エポック

画像の説明を追加してください

10エポック

画像の説明を追加してください

25 エポック

画像の説明を追加してください

おすすめ

転載: blog.csdn.net/qq_36571422/article/details/123883196