[Transfer Learning Paper 5] Generate To Adapt: Aligning Domains using Generative Adversarial Networks (principles and reproduction)


Preface

  • I haven’t updated in a long time, so I am starting to keep records to push myself.
  • This series records the principles and reproduction work of transfer-learning papers I read while preparing for graduate school.


Article introduction

  • This paper was published at CVPR 2018 by Swami Sankaranarayanan, Yogesh Balaji, Carlos D. Castillo, and Rama Chellappa.
  • Joint feature space: a feature representation shared by the source and target domains. The better the two domains are aligned in this space, the better the model transfers.
  • The main contribution is an unsupervised domain adaptation method based on adversarial image generation that learns this joint feature space directly. What sets it apart from previous methods is that it combines generative and discriminative ideas: the adversarial image-generation process is used to learn a feature space that minimizes the discrepancy between the source and target feature distributions.
  • The adversarial image generation method learns shared feature embeddings directly from labeled source-domain data and unlabeled target-domain data.
  • The generator G produces "fake" images that should resemble real source-domain images. The discriminator D has to tell whether an image is real or generated. The feature extractor F learns a representation that fits both domains: the features it extracts must support the classification task, while also making the generated images close enough to real ones that D has difficulty telling them apart.

Model structure

[Figure: model overview. F (feature extractor) feeds both C (classifier, label prediction branch) and G (generator); real source images and G's generated images are fed into D (discriminator).]

Overview of the role of components

  • In the figure above, F is the feature extractor, C the classifier, G the generator, and D the discriminator. F extracts features from source- and target-domain images, C predicts labels for source-domain data, G generates confusing images, and D separates real images from generated ones.
  • The model is divided into two branches (a sketch of how G's input is assembled follows this list):
    • Label prediction branch (processes source-domain data): source-domain data is featurized by F, then fed into C for classification, and the cross-entropy classification loss is computed.
    • Generative adversarial branch (transfer learning and domain adaptation): G generates images meant to confuse D. The generator maps the extracted features (through transposed convolutions) to images that are similar to, but not identical with, the source-domain images; these generated images serve as the confusion data for transfer. For G, the source-domain generated images are fed into D, yielding a domain classification loss and a label classification loss. D itself computes three losses: real source images go directly into D, giving a domain classification loss and a label classification loss; target-domain generated images from G give a domain classification loss; and source-domain generated images from G give another domain classification loss.
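To make the generator branch concrete, here is a minimal sketch of how G's input embedding can be assembled, assuming the class layout of the code section below; the shapes and class count are illustrative assumptions. The source features are concatenated with a one-hot class vector that reserves one extra slot for a "fake" class; G appends the noise vector internally.

import torch
import torch.nn.functional as Fn

# Illustrative shapes (assumptions): batch of 8, 10 classes, 128-dim features (ndim = 2 * ndf).
batch, nclasses, ndim = 8, 10, 128
src_feat = torch.randn(batch, ndim)           # stands in for F(x_s), the source features
src_y = torch.randint(0, nclasses, (batch,))  # source class labels

# One-hot labels with one extra slot reserved for a "fake" class -> nclasses + 1 entries,
# matching the "nclasses + 1" term in _netG's input channel count in the code section.
onehot = Fn.one_hot(src_y, num_classes=nclasses + 1).float()

g_input = torch.cat([src_feat, onehot], dim=1)  # shape: (batch, ndim + nclasses + 1)
# _netG.forward reshapes this to (batch, ndim + nclasses + 1, 1, 1) and appends noise z.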

Feature extractor F

  • Minimize the classification loss on real source-domain data: lets F learn a feature representation well suited to real source data

  • Minimize the label classification loss on source-domain generated data: lets F also understand and learn the characteristics of the generated source data

  • Make D mistakenly believe that target-domain generated data is real: pushes the generated data toward the target domain and confuses the discriminator D

  • Source domain real data label classification loss (writing $D_{data}$ for D's domain head and $D_{cls}$ for its class head):
    $\mathcal{L}_{F,cls} = \mathbb{E}_{(x_s, y_s) \sim S}\left[-\log C(F(x_s))_{y_s}\right]$

  • Source domain generated data label classification loss:
    $\mathcal{L}_{F,gen} = \mathbb{E}_{(x_s, y_s) \sim S}\left[-\log D_{cls}(G(F(x_s)))_{y_s}\right]$

  • Target domain generated data domain classification loss (for confusion, so the domain label is 1):
    $\mathcal{L}_{F,adv} = \mathbb{E}_{x_t \sim T}\left[-\log D_{data}(G(F(x_t)))\right]$

  • These loss functions are designed so that optimizing them adjusts the parameters of F to fit both the source-domain and target-domain data, which is what produces the domain adaptation and transfer effect. Note that the domain classification loss on real source data is independent of F, because real source images are fed into D directly without passing through F. A minimal sketch of an F update follows.
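A minimal sketch of the three F losses, assuming netF, netC, netG, netD are instances of the classes from the code section below, and assuming (as in common reproductions) that unlabeled target samples use the extra "fake" one-hot slot; alpha is a hypothetical trade-off weight:

import torch
import torch.nn.functional as Fn

def f_step_losses(netF, netC, netG, netD, src_x, src_y, tgt_x, nclasses, alpha=0.1):
    """The three losses that drive F's update (alpha is an assumed weight)."""
    src_feat, tgt_feat = netF(src_x), netF(tgt_x)

    # 1) classification loss on real source data, through C
    loss_cls_real = Fn.cross_entropy(netC(src_feat), src_y)

    # 2) label classification loss on source generated images, through D's class head
    onehot_s = Fn.one_hot(src_y, nclasses + 1).float()
    _, src_gen_logits = netD(netG(torch.cat([src_feat, onehot_s], 1)))
    loss_cls_gen = Fn.cross_entropy(src_gen_logits, src_y)

    # 3) fool D on target generated images: domain label 1 ("real")
    fake_y = torch.full((tgt_x.size(0),), nclasses, dtype=torch.long)  # target labels unknown
    onehot_t = Fn.one_hot(fake_y, nclasses + 1).float()
    tgt_dom, _ = netD(netG(torch.cat([tgt_feat, onehot_t], 1)))
    loss_adv = Fn.binary_cross_entropy(tgt_dom, torch.ones_like(tgt_dom))

    return loss_cls_real + alpha * (loss_cls_gen + loss_adv)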

Discriminator D

Through training, D learns to judge the authenticity of, and the class information in, three kinds of data: real source-domain images, fake images generated by G from source features, and fake images generated by G from target features.

  • Judge source-domain generated data as fake: learned through the domain classification loss on source generated data, whose domain labels are set to 0 (fake).
  • Judge real source-domain data as real and minimize its classification error: the domain classification loss on real source data (real, domain label 1) plus the label classification loss on real source data.
  • Judge target-domain generated data as fake: the domain classification loss on target generated data (fake, domain label 0).
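A corresponding sketch of D's three losses under the same assumptions; the generated images are detached so this backward pass puts gradients on D alone:

import torch
import torch.nn.functional as Fn

def d_step_loss(netF, netG, netD, src_x, src_y, tgt_x, nclasses):
    """D's three losses (a sketch, not the paper's exact weighting)."""
    # real source images go directly into D: domain label 1 plus label classification
    real_dom, real_logits = netD(src_x)
    loss_src_real = Fn.binary_cross_entropy(real_dom, torch.ones_like(real_dom)) \
                  + Fn.cross_entropy(real_logits, src_y)

    # source generated images: domain label 0 (fake)
    onehot_s = Fn.one_hot(src_y, nclasses + 1).float()
    src_gen = netG(torch.cat([netF(src_x), onehot_s], 1)).detach()
    src_dom, _ = netD(src_gen)
    loss_src_gen = Fn.binary_cross_entropy(src_dom, torch.zeros_like(src_dom))

    # target generated images: domain label 0 (fake)
    fake_y = torch.full((tgt_x.size(0),), nclasses, dtype=torch.long)
    onehot_t = Fn.one_hot(fake_y, nclasses + 1).float()
    tgt_gen = netG(torch.cat([netF(tgt_x), onehot_t], 1)).detach()
    tgt_dom, _ = netD(tgt_gen)
    loss_tgt_gen = Fn.binary_cross_entropy(tgt_dom, torch.zeros_like(tgt_dom))

    return loss_src_real + loss_src_gen + loss_tgt_gen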

Generator G

G competes with the discriminator D through adversarial training, and in doing so drives the learning of the feature extractor F.

  • Deceive the discriminator D: make D mistakenly believe that source-domain generated data is real, via the domain classification loss on source generated data (for confusion, so the domain label is 1).
  • Minimize the classification error on source-domain generated data, via the label classification loss on source generated data.
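A sketch of G's two losses under the same assumptions; the features are detached here so this update leaves F untouched:

import torch
import torch.nn.functional as Fn

def g_step_loss(netF, netG, netD, src_x, src_y, nclasses):
    """G's update: fool D's domain head (label 1) and match D's class head
    on source generated images."""
    onehot_s = Fn.one_hot(src_y, nclasses + 1).float()
    src_gen = netG(torch.cat([netF(src_x).detach(), onehot_s], 1))
    dom, logits = netD(src_gen)
    loss_fool = Fn.binary_cross_entropy(dom, torch.ones_like(dom))
    loss_cls = Fn.cross_entropy(logits, src_y)
    return loss_fool + loss_cls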

Classifier C

Minimize the classification error on real source-domain data: the label classification loss on real source data. Note that the class head inside the discriminator mentioned above does not share parameters with C: they are two classifiers with the same purpose but separate parameters. D keeps a small classifier of its own rather than reusing C to compute the classification loss on real source data.
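And C's single loss under the same assumptions, for completeness:

import torch.nn.functional as Fn

def c_step_loss(netF, netC, src_x, src_y):
    """C's only loss: cross-entropy on real source features. This is separate
    from the class head inside D, which has its own parameters."""
    return Fn.cross_entropy(netC(netF(src_x)), src_y)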

Note

As the feature extractor section above shows, backpropagation for each component in this paper is independent. Many losses involve F, but F's backward pass covers only the three losses above; when the other components update their parameters, F's parameters are not touched. In other words, when writing code for this model, the four components must be implemented separately and backpropagated separately, as the sketch below illustrates.
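A minimal sketch of this wiring, reusing the hypothetical loss helpers sketched above; the optimizer type and hyperparameters are assumptions, not the paper's exact settings:

import torch.optim as optim

# One optimizer per component, so each backward()/step() pair updates only the
# intended parameters (Adam settings here are assumptions).
opt_F = optim.Adam(netF.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_C = optim.Adam(netC.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_G = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

def train_step(src_x, src_y, tgt_x, nclasses):
    # Each component gets its own zero_grad / backward / step, so no loss
    # leaks parameter updates into the other components.
    opt_D.zero_grad()
    d_step_loss(netF, netG, netD, src_x, src_y, tgt_x, nclasses).backward()
    opt_D.step()

    opt_G.zero_grad()
    g_step_loss(netF, netG, netD, src_x, src_y, nclasses).backward()
    opt_G.step()

    opt_F.zero_grad()
    f_step_losses(netF, netC, netG, netD, src_x, src_y, tgt_x, nclasses).backward()
    opt_F.step()

    opt_C.zero_grad()
    c_step_loss(netF, netC, src_x, src_y).backward()
    opt_C.step()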

Model goal

The feature maps extracted by F should give similar feature distributions for corresponding data in the source and target domains, thereby improving test accuracy on the target domain. The alignment of F(src) and F(tar) is achieved through G and D: the adversarial updates bias F's parameters toward making F(src) and F(tar) similar, which is the model's optimization goal.

Code

import torch
import torch.nn as nn

"""
生成器G
"""

class _netG(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netG, self).__init__()

        self.ndim = 2 * opt.ndf
        self.ngf = opt.ngf
        self.nz = opt.nz
        self.gpu = opt.gpu
        self.nclasses = nclasses

        # The generator's main network: transposed convolutions upsampling a 1x1 input to a 3-channel image
        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.nz + self.ndim + nclasses + 1, self.ngf * 8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(self.ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        batch_size = input.size(0)
        # Reshape the (feature + one-hot label) embedding into a 1x1 spatial map
        input = input.view(-1, self.ndim + self.nclasses + 1, 1, 1)
        # Sample the noise vector z on the same device as the input
        noise = torch.randn(batch_size, self.nz, 1, 1, device=input.device)
        output = self.main(torch.cat((input, noise), 1))
        return output


"""
鉴别器D
"""

class _netD(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netD, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 3, 1, 1),
            nn.BatchNorm2d(self.ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf * 2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf * 4, self.ndf * 2, 3, 1, 1),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(4, 4)
        )

        self.classifier_c = nn.Sequential(nn.Linear(self.ndf * 2, nclasses))
        self.classifier_s = nn.Sequential(
            nn.Linear(self.ndf * 2, 1),
            nn.Sigmoid())

    def forward(self, input):
        output = self.feature(input)
        # Domain head: real/fake probability per image
        output_s = self.classifier_s(output.view(-1, self.ndf * 2))
        output_s = output_s.view(-1)
        # Class head: per-class logits (separate parameters from classifier C)
        output_c = self.classifier_c(output.view(-1, self.ndf * 2))
        return output_s, output_c


"""
特征提取器F
"""

class _netF(nn.Module):
    def __init__(self, opt):
        super(_netF, self).__init__()

        self.ndf = opt.ndf
        self.feature = nn.Sequential(
            nn.Conv2d(3, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf, 5, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(self.ndf, self.ndf * 2, 5, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        output = self.feature(input)
        # Flatten the (2*ndf, 1, 1) feature map into a (batch, 2*ndf) embedding
        return output.view(-1, 2 * self.ndf)


"""
分类器C
"""

class _netC(nn.Module):
    def __init__(self, opt, nclasses):
        super(_netC, self).__init__()
        self.ndf = opt.ndf
        self.main = nn.Sequential(
            nn.Linear(2 * self.ndf, 2 * self.ndf),
            nn.ReLU(inplace=True),
            nn.Linear(2 * self.ndf, nclasses),
        )

    def forward(self, input):
        output = self.main(input)
        return output
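To check that the four networks above fit together, here is a small smoke test; the hyperparameters (32x32 RGB inputs, 10 classes) are assumptions in the style of a digits benchmark, not the paper's exact configuration:

import argparse
import torch
import torch.nn.functional as Fn

opt = argparse.Namespace(ndf=64, ngf=64, nz=512, gpu=-1)
nclasses = 10

netF, netC = _netF(opt), _netC(opt, nclasses)
netG, netD = _netG(opt, nclasses), _netD(opt, nclasses)

x = torch.randn(4, 3, 32, 32)                 # a dummy batch of images
feat = netF(x)                                # -> (4, 128), i.e. (batch, 2 * ndf)
print(netC(feat).shape)                       # torch.Size([4, 10])

labels = torch.randint(0, nclasses, (4,))
onehot = Fn.one_hot(labels, nclasses + 1).float()
fake = netG(torch.cat([feat, onehot], 1))     # -> (4, 3, 32, 32)
dom, logits = netD(fake)
print(dom.shape, logits.shape)                # torch.Size([4]) torch.Size([4, 10])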

Origin: blog.csdn.net/weixin_51293984/article/details/135031135