[Transfer Learning] Domain Adaptive Code Analysis

1. Overview

        The code comes from https://github.com/jindongwang/transferlearning ; you can download it from GitHub. The code covered in this article is located at Code->DeepDA. For the theoretical background, see: [Transfer Learning] Domain Adaptation.

        The overall network structure can be regarded as a classification network (such as ResNet-50) plus an fc_adapt module. The loss function is L = L_c(x_i, y_i) + \lambda * Distance(D_s, D_t): the original cross-entropy classification loss is followed by a loss that measures the distance between the source domain D_s and the target domain D_t (generally MMD), and the hyperparameter \lambda controls the weight of this distance term.
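
        As a rough sketch of this objective (the names backbone, classifier, mmd, and lam below are placeholders, not the repository's actual API), the combined loss can be computed like this:

import torch.nn.functional as F

# Illustrative sketch of the combined objective; backbone, classifier, mmd and
# lam are placeholder names, not the repository's exact API.
def total_loss(backbone, classifier, mmd, x_s, y_s, x_t, lam):
    f_s = backbone(x_s)              # source-domain features
    f_t = backbone(x_t)              # target-domain features
    clf_loss = F.cross_entropy(classifier(f_s), y_s)
    transfer_loss = mmd(f_s, f_t)    # distance between the two domains
    return clf_loss + lam * transfer_loss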

 2. Code Analysis

        1. Data loader

        In the main function, the code related to the data loader is:

source_loader, target_train_loader, target_test_loader, n_class = load_data(args)

        Jumping into the load_data(.) function, you can see that the data is split into the following three parts, corresponding to the source domain, the target-domain training set, and the target-domain test set:

source_loader,target_train_loader,target_test_loader

        The actual work is done by load_data in data_loader, which builds the dataset (via torchvision's ImageFolder) and wraps it in a DataLoader:

data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test'])
data_loader = get_data_loader(data, batch_size=batch_size,
                              shuffle=True if train else False,
                              num_workers=num_workers, **kwargs, drop_last=True if train else False)
n_class = len(data.classes)
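
        Putting these pieces together, a minimal sketch of how the three loaders could be constructed (the make_loader helper, the dataset root, and the domain folder names are hypothetical, assuming an ImageFolder-style directory layout):

import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Hypothetical helper mirroring the logic above; paths and names are illustrative.
def make_loader(data_folder, batch_size, train, num_workers=4):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    data = datasets.ImageFolder(root=data_folder, transform=transform)
    loader = DataLoader(data, batch_size=batch_size, shuffle=train,
                        num_workers=num_workers, drop_last=train)
    return loader, len(data.classes)

root = "/path/to/office31"  # hypothetical dataset root
source_loader, n_class = make_loader(os.path.join(root, "amazon"), 32, train=True)
target_train_loader, _ = make_loader(os.path.join(root, "webcam"), 32, train=True)
target_test_loader, _ = make_loader(os.path.join(root, "webcam"), 32, train=False)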

        2. Model

        In the main function, the code related to the model is:

model = get_model(args)
def get_model(args):
    model = models.TransferNet(
        args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone, max_iter=args.max_iter, use_bottleneck=args.use_bottleneck).to(args.device)
    return model

        This function instantiates the network from the TransferNet class in the models file. Following the code into TransferNet, the class takes the following parameters: the number of classes, the backbone network type, the transfer_loss type, and a few parameters (use_bottleneck, bottleneck_width, max_iter) that configure the rest of the network.

class TransferNet(nn.Module):
    def __init__(self, num_class, base_net='resnet50', 
        transfer_loss='mmd', use_bottleneck=True, 
        bottleneck_width=256, max_iter=1000, **kwargs):
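
        Inside __init__, the bottleneck and classifier might be assembled roughly as in the sketch below (a plausible reconstruction, not the repository's exact code; the layer choices and the fixed feature dimension are assumptions):

# Plausible sketch of how __init__ could wire up the bottleneck and classifier;
# the ReLU/Dropout choices and the fixed feature_dim are assumptions.
feature_dim = 2048  # e.g. the ResNet-50 output dimension
if use_bottleneck:
    self.bottleneck_layer = nn.Sequential(
        nn.Linear(feature_dim, bottleneck_width),
        nn.ReLU(),
        nn.Dropout(0.5),
    )
    feature_dim = bottleneck_width
self.classifier_layer = nn.Linear(feature_dim, num_class)
self.criterion = torch.nn.CrossEntropyLoss()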

        Next, look at the forward function of the network, which can be broken down into the following parts:

                ① Backbone feature extraction

source = self.base_network(source)
target = self.base_network(target)
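
                For intuition, a feature-extracting backbone of this kind can be sketched as a torchvision ResNet-50 with its classification head removed (an illustrative stand-in, not the repository's backbone class):

import torch.nn as nn
from torchvision import models as tv_models

# Illustrative stand-in for base_network: ResNet-50 without its final fc layer,
# so forward() returns a flat 2048-dim feature vector per image.
class ResNet50Features(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = tv_models.resnet50(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])  # drop final fc

    def forward(self, x):
        f = self.features(x)          # shape (B, 2048, 1, 1)
        return f.view(f.size(0), -1)  # shape (B, 2048)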

                ② Source Domain Classification

source_clf = self.classifier_layer(source)

                The classifier in the code is a single fully connected layer whose parameters are the input feature dimension and the number of output classes:

self.classifier_layer = nn.Linear(feature_dim, num_class)

                ③ Source Domain Classification Loss Function Calculation

clf_loss = self.criterion(source_clf, source_label)

                A standard cross-entropy loss function is used in the code:

self.criterion = torch.nn.CrossEntropyLoss()

                ④ Transfer loss

                This part is the biggest difference between a domain-adaptive network and a traditional classification network: in addition to the usual clf_loss, the network also computes a transfer_loss.

                The forward function contains extra preparation for three methods, lmmd, daan, and bnm; the branches are largely similar, so lmmd is taken as the example. The lmmd branch is as follows:

if self.transfer_loss == "lmmd":
    kwargs['source_label'] = source_label
    target_clf = self.classifier_layer(target)
    kwargs['target_logits'] = torch.nn.functional.softmax(target_clf, dim=1)

                The main job of this code is to take source_label from the argument list, use the same classifier as above to classify the target-domain features extracted by the backbone (applying softmax to the logits; the result is stored as target_logits), and then pass the source features, target features, and these extra arguments to the adapt_loss(.) module to compute the transfer_loss.

                From the subsequent code we can also see that the plain mmd loss does not require this extra step.

# set up once in __init__():
self.adapt_loss = TransferLoss(**transfer_loss_args)
# called in forward() on the two batches of features:
transfer_loss = self.adapt_loss(source, target, **kwargs)

                The TransferLoss class is a dispatcher that selects among six different loss functions; here we take mmd as an example:

if loss_type == "mmd":
    self.loss_func = MMDLoss(**kwargs)
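
                The dispatch pattern can be sketched roughly as follows (a simplified illustration: only the mmd branch is spelled out, the constructor signature is an assumption, and MMDLoss is the class shown below):

import torch.nn as nn

# Simplified sketch of a loss dispatcher; only the mmd branch is shown and the
# constructor signature is an assumption, not the repository's exact code.
class TransferLoss(nn.Module):
    def __init__(self, loss_type='mmd', **kwargs):
        super().__init__()
        if loss_type == "mmd":
            self.loss_func = MMDLoss(**kwargs)
        else:
            raise NotImplementedError(f"loss_type '{loss_type}' is not sketched here")

    def forward(self, source, target, **kwargs):
        # kwargs stays empty for plain mmd; other losses may need extra inputs
        return self.loss_func(source, target, **kwargs)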

                The complete code of MMDLoss is as follows:

import torch
import torch.nn as nn

class MMDLoss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        # total number of samples from both domains
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # broadcast to compute all pairwise squared Euclidean distances
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0-total1)**2).sum(2)
        # bandwidth: either fixed, or estimated as the mean pairwise squared distance
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        # build kernel_num bandwidths, geometrically spaced by kernel_mul around the estimate
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        # sum of Gaussian kernels evaluated at the different bandwidths
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        # squared Euclidean distance between the mean embeddings of the two domains
        loss = 0.0
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta.T)
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            loss = torch.mean(XX + YY - XY - YX)
            return loss

                There is quite a bit of mathematics inside, which I will not unpack here; the role of its forward function is simply to compute the gap between the source domain and the target domain and return it as a loss.
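
                For reference, the rbf branch is the usual empirical estimate of the squared maximum mean discrepancy,

                MMD^2(D_s, D_t) = E[k(x_s, x_s')] + E[k(x_t, x_t')] - 2 E[k(x_s, x_t)]

                where k is the multi-bandwidth Gaussian kernel built in guassian_kernel; the XX + YY - XY - YX combination of kernel-matrix block means in forward computes exactly this estimate.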

                ⑤ Return values

                The model's forward function finally returns two values, clf_loss and transfer_loss, which are used in the subsequent backward pass.
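
                Putting ① through ⑤ together, the forward pass can be summarized by the condensed sketch below (simplified: the bottleneck layer and the lmmd/daan/bnm-specific kwargs are omitted):

# Condensed sketch of TransferNet.forward (simplified; the bottleneck layer and
# the lmmd/daan/bnm-specific kwargs discussed above are omitted).
def forward(self, source, target, source_label):
    source = self.base_network(source)                    # ① backbone features
    target = self.base_network(target)
    source_clf = self.classifier_layer(source)            # ② source classification
    clf_loss = self.criterion(source_clf, source_label)   # ③ cross-entropy loss
    transfer_loss = self.adapt_loss(source, target)       # ④ domain-distance loss
    return clf_loss, transfer_loss                        # ⑤ two losses for training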

        3. Training

        During training, source-domain images and labels are drawn from source_loader, and target-domain images are drawn from target_train_loader (target-domain labels are not needed).

        These three inputs are fed into the model to obtain clf_loss and transfer_loss; the final loss is clf_loss plus transfer_loss multiplied by the weight coefficient lambda:

clf_loss, transfer_loss = model(data_source, data_target, label_source)
loss = clf_loss + args.transfer_loss_weight * transfer_loss

        Then comes the usual backward pass and optimizer step:

optimizer.zero_grad()
loss.backward()
optimizer.step()
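
        A full training iteration therefore looks roughly like the sketch below (pairing the two loaders with zip is an assumption about how source and target batches are combined, not necessarily the script's exact logic):

# Illustrative training loop; zipping the two loaders is an assumed way of
# pairing source and target batches, not the training script's exact code.
model.train()
for (data_source, label_source), (data_target, _) in zip(source_loader, target_train_loader):
    data_source = data_source.to(args.device)
    label_source = label_source.to(args.device)
    data_target = data_target.to(args.device)

    clf_loss, transfer_loss = model(data_source, data_target, label_source)
    loss = clf_loss + args.transfer_loss_weight * transfer_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()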

        4. Test

        The testing process is similar to training, except that the transfer loss is no longer needed: the model's predict function is used instead of the default forward function.

s_output = model.predict(data)

        Compared with the forward function above, this function skips the adapt_loss computation entirely:

def predict(self, x):
    features = self.base_network(x)
    x = self.bottleneck_layer(features)
    clf = self.classifier_layer(x)
    return clf
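
        A typical evaluation loop built on predict might look like this (an illustrative sketch, not the repository's exact test function):

# Illustrative evaluation loop using model.predict; not the script's exact code.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, label in target_test_loader:
        data, label = data.to(args.device), label.to(args.device)
        s_output = model.predict(data)
        pred = s_output.argmax(dim=1)
        correct += (pred == label).sum().item()
        total += label.size(0)
print(f"target accuracy: {100.0 * correct / total:.2f}%")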

Original post: https://blog.csdn.net/weixin_37878740/article/details/131182026