GANの原理とGANを実現するためのPytorchフレームワーク(比較的わかりやすい)

目次

1. GAN を初めて知る

GANとは?

GAN アプリケーションのシナリオ

2. GANの原理構造

(1) 生成的対立ネットワークサブネットワーク

(2) 構造図

(1) 発電機 

(2) 識別器

(3) トレーニングスキル 

3. GAN ネットワーク モデルの選択

(1) モデルの生成

(2) 判別モデル

4. GAN トレーニングの目的関数

(1) モデルの生成

(2) 判別モデル

5. トレーニングアルゴリズム

6.GAN コードの実装

7. mainWindow ウィンドウには、ジェネレーターによって生成された画像が表示されます。

拡大


1. GAN を初めて知る

  • GANとは?

    • GAN (Generative Adversarial Networks): 対立ネットワークを生成します。
    • GAN は、人工知能の分野で最も重要な研究ホットスポットの 1 つであり、広く使用されています。
  • 2014 年、モントリオール大学 Yoshua Bengio の学生である Ian Goodfellow (2018 年 Turing Award 受賞者) は、Generative adj-terminal network (略して GAN) を提案し、深層学習の最もホットな研究の方向性を切り開きました。
  • 2014年から2019年にかけて、GANの研究は着々と進んでおり、研究成果は度々出ており、最新のGANアルゴリズムが画像生成に与える影響は、肉眼では見分けがつかないレベルにまで達しています。GAN の発明により、イアン・グッドフェローは GAN の父の称号を授与され、2017 年の MIT Technology Review Awards で 35 Innovators Goodfellow Award を受賞しました。
  • この方法は、生成ネットワークと識別ネットワークと呼ばれる 2 つのネットワークを利用し、オーディオ、ビデオ、およびテキストの形式で非常にクリエイティブな出力を生成するために使用できます。彼の研究は、人工知能に関する文献で広く引用されています。
  • GAN アプリケーションのシナリオ

    • 画像編集: 画像を指定すると、その画像に基づいてさまざまな画像を生成できます。
    • 悪意のある攻撃の検出: ディープ ラーニングによって生成されたモデルは、ハッキング、悪用、さらには制御される可能性があります。このような敵対的攻撃に対抗するために、敵対的ニューラル ネットワークをトレーニングして、架空の敵としてより多くの偽のトレーニング データを生成し、モデルが演習中にこれらの偽のデータを識別できるようにし、GAN によって生成された偽のデータにより、分類モデルをより多くすることができます。屈強;
    • データ生成: たとえば、医療分野では、トレーニング データの不足がディープ ラーニングの適用に対する最大の障害です。従来のデータ強化の方法は、元の画像を引き伸ばしたり、回転させたり、切り取ったりすることでしたが、これはやはり元の画像であり、GAN を使用することで、より類似したデータを生成することができます。
    • 注意予測: 人間が写真を見るとき、特定の部分だけに注意を向ける場合が多く、GAN モデルを通じて、人間の関心領域がどこにあるかを予測できます。
    • 三次元構造の生成: pix2vox は、GAN に基づくオープン ソース ツールであり、対応する形状だけでなく、対応する色を使用して、手描きの 2 次元画像に基づいて対応する 3 次元構造を生成できます。 3D モデリングのしきい値を下げることができるため、3D プリントが着陸しやすくなります。

2. GANの原理構造

ヒント: 以下の原則の説明は読者にとって退屈かもしれませんが、読者が原則をしっかりと読んでくれることを願っています。このようにしてのみ、GAN の実装原理を真に理解できるからです。

(1) 生成的対立ネットワークサブネットワーク

 GANには以下が含まれます:ジェネレーターネットワークとディスクリミネーターネットワーク。ジェネレーターネットワークGenはサンプルの実際の分布を学習し、判別ネットワークDisはジェネレーターネットワークによって生成されたサンプルと実際のサンプルを区別します。

(2) 構造図

(1) 発電機 

生成モデルはランダム ノイズまたは類似の制御変数を入力として取り、ジェネレータは一般に多層ニューラル ネットワークを使用して実装され、その出力は生成されたサンプル、つまり偽の画像です。そのようなサンプルと実際に与えられたサンプルは次のとおりです。識別モデルと一緒にトレーニングされます。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np
#对于生成器,输入的为正态分布随机数
#输出为: [1,28,28]图片

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=100,out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256,out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512,out_features=784),
            torch.nn.Tanh()#对于生成器使用tanh激活函数更好
        )
    def forward(self,input):
        x = self.fc(input)
        img = x.view(-1,28,28)
        return img

(2) 識別器

弁別モデルは、サンプルが実際のサンプルかジェネレーターによって生成されたサンプルかを区別するバイナリ分類器であり、一般にニューラル ネットワークを使用して実装されます。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np

#判别器的输入为一张图片
#输出为二分类的概率值
#判别器对log(1 - D(G(z)))的判别作为生成器的损失值

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=784,out_features=512),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=512,out_features=256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=256,out_features=1),
            torch.nn.Sigmoid()
        )
    def forward(self,input):
        x = input.view(-1,784)
        x = self.fc(x)
        return x

(3) トレーニングスキル 

  • 生成モデルの場合: トレーニングの目標は、生成されたデータを実際のデータに可能な限り類似させ、識別モデルの識別精度を最小限に抑えることです。
  • 識別モデルの場合: トレーニングの目標は、識別精度を最大化することです。つまり、サンプルが実際のサンプルかジェネレーターによって生成されたサンプルかを区別することです。

このプロセスは矛盾していることがわかるので、次のようになります。

  • トレーニングの過程で、代替最適化の方法が採用され、各反復は 2 つの段階に分けられます。
    • 第 1 段階: まず判別モデルを修正し、生成モデルを最適化して、生成されたデータが判別モデルによって真のサンプルと判断される確率ができるだけ高くなるようにします。
    • 第 2 段階: 生成モデルを修正し、ディスク モデルを最適化し、判別モデルの分類精度を向上させます。

ヒント: トレーニング プロセスの間、ジェネレーターは生成された画像をよりリアルにするために懸命に働きますが、ディスクリミネーターはジェネレーターの画像の信頼性を識別するために懸命に働きます. これは相互ゲームのプロセスであり、お互いが自分自身を向上させます, つまり,彼らは常に対立のプロセスです。トレーニングが進むにつれて、生成モデルによって生成されたサンプルと実際のサンプルの間にほとんど差がなくなり、判別モデルはサンプルが真か偽かを正確に判断できなくなります。このとき、分類エラー率は 0.5 (Nash平衡)

3. GAN ネットワーク モデルの選択

生成対立ネットワークは抽象的なフレームワークであり、どのモデルが生成モデルであるか判別モデルであるかは特定されていません. ニューラル ネットワーク モデル、畳み込みニューラル ネットワーク モデル、またはその他の機械学習モデルのいずれかです。

(1) モデルの生成

        この論文では、選択された怒っているモデルはニューラル ネットワーク モデルです。画像などのサンプルデータを生成するタイプなどの入力変数に応じて、生成された入力モデルは、カテゴリやランダムノイズなどの隠れ変数を受け取り、トレーニングサンプルに類似したサンプルデータ (画像など) を出力します。

(2) 判別モデル

        判別モデルは通常、分類問題のニューラル ネットワークを使用して、サンプル (実際のデータとジェネレーターによって生成されたデータが与えられた場合) の真と偽を区別します。これは 2 カテゴリの問題です。

4. GAN トレーニングの目的関数

ヒント: 生成モデルと判別モデルを決定する前に、まずロジスティック回帰モデルを理解してください。

ロジスティック回帰、または対数確率回帰は、バイナリ分類問題の分類アルゴリズムです. シグモイド関数は、サンプルが陽性サンプルに属する確率を推定するために使用されます (詳細な導出については、「機械学習の原則、アルゴリズム」を読むことをお勧めします、およびアプリケーション」)。

ロジスティック回帰尤度関数:

  • 回帰対数関数と世代対立の違い:
    • ロジスティック回帰がトレーニングの最適点に達すると、負のサンプルの予測出力は 0 に近くなります。
    • 敵対的生成ネットワークにおける識別モデルの敵対的サンプルの出力確率値は、最適点で 0.5 に近くなります。 

(1) モデルの生成

(2) 判別モデル

5. トレーニングアルゴリズム

  

6.GAN コードの実装

ヒント: コードは Github に配置されており、読者は自分でダウンロードできます: https://github.com/KeepTryingTo/Pytorch-GAN

 

7. mainWindow ウィンドウには、ジェネレーターによって生成された画像が表示されます。

ヒント: ジェネレーターによって表示される画像を表示するプログラム (mainWindow.py) は次のとおりです. 前回のトレーニング後に保存されたジェネレーター モデルをロードし、このモデルを使用して次のようにデジタル画像をランダムに生成します:

(1) mainWindow.py を実行 初期インターフェースは以下の通り

 

 (2) 画像の生成ボタンをクリックすると、各クリックで生成される数字は同じではありません。

 

 

拡大

pytorch での detach の役割

参考記事:

「TensorFlow ディープラーニング」

「機械学習の原理・アルゴリズム・応用」

https://www.jiqizhixin.com/articles/2019-04-15-6

https://b23.tv/6P7M8mh

おすすめ

転載: blog.csdn.net/Keep_Trying_Go/article/details/130362296