[Transfer Learning Paper 5] Generate To Adapt Aligning Domains using Generative Adversarial Networks 論文の原理と再現作業

適応するための生成: 敵対的生成ネットワークを使用したドメインの調整 適応するための生成: 敵対的生成ネットワークを使用したドメインの調整

序文

  • 長い間更新していなかったので、自分を追い込むために記録を始めました。
  • 大学院の準備中に、原理の説明と関連する転移学習論文の再現作業を記録します。

質問

記事紹介

  • この記事は 2018 年に CVPR に掲載され、Swami Sankaranarayanan、Yogesh Balaji、Carlos D. Castillo、Rama Chellappa によって書かれました。
  • 結合特徴空間: モデルによって学習されたソース ドメインとターゲット ドメインの間で共有される特徴表現は、ソース ドメインとターゲット ドメインの間でより適切に調整され、より適切に転送されます。
  • この記事の主な貢献は、結合特徴空間を直接学習できる敵対的画像生成のための教師なしドメイン適応方法を提案することです。以前の方法と比較したこの方法の独自性は、生成的アイデアと識別的アイデアの両方を使用し、画像生成の敵対的プロセスを使用して、ソース ドメインとターゲット ドメインの特徴分布を最小化する特徴空間を学習することです。
  • ソースドメインからのラベル付きデータとターゲットからのラベルなしデータを使用して、共有特徴の埋め込みを直接学習する敵対的画像生成方法が提案されています。
  • ジェネレーター G は、ソース ドメイン データに似た「偽の」データを生成する責任を負い、これらのデータが実際のソース ドメイン データに似ていることを期待します。識別器 D のタスクは、生成されたデータが本物か偽物かを識別することです。特徴抽出器 F の目標は、ソース ドメインとターゲット ドメインの両方に適応できる特徴表現を学習して、F によって抽出された特徴をタスクでより適切に使用できるようにし、生成されたデータをより適切なものにすることです。実際のデータなので、識別子がそれらを区別するのは困難です。

モデル構造

ここに画像の説明を挿入します

コンポーネントの役割の概要

  • 上の図では、F は特徴抽出器、C は分類器、G は生成器、D は識別器です。F はソース ドメインとターゲット ドメインの画像特徴を抽出するために使用され、C はソース ドメインのデータ ラベルを予測するために使用され、G は紛らわしい生成データを生成するために使用され、D は実際のデータと生成されたデータを識別するために使用されます。
  • モデルは 2 つのブランチに分かれています。
    • ラベル予測ネットワーク (ソース ドメイン データの処理) : ソース ドメイン データは F によって特徴付けられ、分類のために C に入力され、クロス エントロピー分類損失が計算されます。
    • 敵対的生成ネットワーク (転移学習とドメイン適応) : G は、混乱を招くデータ D の識別タスクを生成するために使用されます。G ジェネレーターは、抽出されたソース ドメイン データの特徴 (転置畳み込み、ラベルは 0) を処理して、ソース ドメインと類似しているが同一ではないデータを生成します。これらの生成されたデータは、転移学習で混同データとして使用されます。G は、ドメイン分類損失とラベルを計算しますこれらのソース ドメインで生成されたデータを D に入れて分類損失を計算します。D は、実際のデータと生成されたデータを区別する識別に使用されます。3 つの損失を計算します。まず、ソース ドメインの実データを直接 D に入力し、ドメイン分類損失を計算します。とラベル分類損失を計算し、G によって抽出されたターゲット ドメイン生成データが D によって計算されたドメイン分類損失に入れられます。最後に、G によって抽出されたソース ドメイン生成データは、D によって計算されたドメイン分類損失に入れられます。

特徴抽出器 F

  • ソース ドメインの実データの分類損失を最小限に抑える: F がソース ドメインの実データにより適した特徴表現を学習できるようにします

  • ソース ドメインで生成されたデータの分類損失を最小限に抑える: F がソース ドメインで生成されたデータの特性を理解し、学習できるようにします。

  • D がターゲット ドメインで生成されたデータが正しいと誤って信じているとします。生成されたデータがターゲット ドメインにより適したものになり、識別子 D が混乱します。

  • ソースドメインの実データラベル分類損失:
    ここに画像の説明を挿入します

  • ソース ドメインで生成されたデータ ラベル分類の損失:
    ここに画像の説明を挿入します

  • ターゲット ドメインはデータ ドメイン分類損失を生成します (混乱を避けるため、ドメイン ラベルは 1 です)。ここに画像の説明を挿入します

  • これらの損失関数の設計と最適化の目的は、特徴抽出器 F のパラメーターを調整して、ソース ドメイン データとターゲット ドメイン データによりよく適応できるようにすることにより、より優れたドメイン適応と転移学習効果を達成することです。ソース ドメイン ソース ドメインの実データ ドメイン データのドメイン分類損失は F から独立しています。

弁別子D

トレーニングを通じて、さまざまな種類のデータの信頼性と分類情報を判断します。これらのデータには、ソース ドメインの本物のデータ、ジェネレーター G によって生成された偽のデータ、およびターゲット ドメインによって生成された偽のデータが含まれます。

  • ソース ドメインで生成されたデータが偽であるかどうかを判断するには、ソース ドメインで生成されたデータのドメイン分類損失を通じて学習します。これらのデータのドメイン ラベルは 0 (偽のデータを示す) に設定されます。
  • ソース ドメイン内の実際のデータが本物であるかどうかを判断し、分類誤差を最小限に抑えます - ソース ドメイン内の実際のデータ ドメイン分類損失 (true、ドメイン ラベルは 1) + ソース ドメイン内の実際のデータ ラベル分類損失
  • ターゲット ドメインで生成されたデータが false かどうかを判断します。ターゲット ドメインで生成されたデータのドメイン分類損失 (false、ドメイン ラベルは 0)

ジェネレーターG

敵対的トレーニングを通じて識別器 D と競合し、特徴抽出器 F の学習を促進します。

  • 識別子 D を欺く: D にソース ドメインで生成されたデータが真であると誤って信じ込ませる - ソース ドメインで生成されたデータのドメイン分類損失 (混乱を避けるため、ドメイン ラベルは 1)
  • ソース ドメインで生成されたデータの分類エラーを最小限に抑える - ソース ドメインで生成されたデータ ラベルの分類損失

分類子 C

ソース ドメインの実データの分類誤差を最小限に抑える: ソース ドメインの実データの分類損失 (上記の識別器での分類損失は、分類器 C のパラメーターと共有されません。つまり、それらは分類器 C のパラメーターと共有されません)。同じ意味を持つ 2 つの分類ですが、パラメーターが異なります。識別子、ソース ドメインの実データ分類損失を計算するために C を使用する代わりに、独自に使用するための小さな分類器が識別子内にあります)

警告する

上記の特徴抽出器の紹介からわかるように、この論文の各コンポーネントの逆伝播は独立しています. 多くの損失には F が関係しますが、F の逆伝播は上記 3 つの損失にのみ関係します. その他の F のパラメータは影響しませんパラメータが更新されると更新されます。つまり、モデルを設計するコードを記述する場合、上記の 4 つのコンポーネントを個別に記述し、バックプロパゲーションを個別に実行する必要があります。

モデルの目標

F によって抽出された特徴マップにより、ソース ドメインとターゲット ドメインの類似データの特徴分布がより類似するため、テストの精度が向上します。F(src) と F(tar) は G に似ています。これは、パラメーター調整が F(src) と F(tar) を似せる方向に偏っていることを意味し、これがモデル最適化の目標です。

コード

import torch
import torch.nn as nn
from torch.autograd import Variable

"""
生成器G
"""

class _netG(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netG, self).__init__()

        self.ndim = 2 * opt.ndf
        self.ngf = opt.ngf
        self.nz = opt.nz
        self.gpu = opt.gpu
        self.nclasses = nclasses

        # 定义生成器的主要神经网络结构
        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.nz + self.ndim + nclasses + 1, self.ngf * 8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        batchSize = input.size()[0]
        input = input.view(-1, self.ndim + self.nclasses + 1, 1, 1)
        noise = torch.FloatTensor(batchSize, self.nz, 1, 1).normal_(0, 1)
        if self.gpu >= 0:
            noise = noise.cuda()
        noisev = Variable(noise)
        output = self.main(torch.cat((input, noisev), 1))
        return output


"""
鉴别器D
"""

class _netD(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netD, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 3, 1, 1),
            nn.BatchNorm2d(self.ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf * 2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf * 4, self.ndf * 2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(4, 4)
        )

        self.classifier_c = nn.Sequential(nn.Linear(self.ndf * 2, nclasses))
        self.classifier_s = nn.Sequential(
            nn.Linear(self.ndf * 2, 1),
            nn.Sigmoid())

    def forward(self, input):
        output = self.feature(input)
        output_s = self.classifier_s(output.view(-1, self.ndf * 2))
        output_s = output_s.view(-1)
        output_c = self.classifier_c(output.view(-1, self.ndf * 2))
        return output_s, output_c


"""
特征提取器F
"""

class _netF(nn.Module):
    def __init__(self, opt):
        super(_netF, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf * 2, 5, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        output = self.feature(input)
        return output.view(-1, 2 * self.ndf)


"""
分类器C
"""

class _netC(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netC, self).__init__()
        self.ndf = opt.ndf
        self.main = nn.Sequential(
            nn.Linear(2 * self.ndf, 2 * self.ndf),
            nn.ReLU(inplace=True),
            nn.Linear(2 * self.ndf, nclasses),
        )

    def forward(self, input):
        output = self.main(input)
        return output

おすすめ

転載: blog.csdn.net/weixin_51293984/article/details/135031135