I. Overview
The code comes from https://github.com/jindongwang/transferlearning ; you can download it from GitHub. The code covered in this article is located at Code->DeepDA. The theoretical background can be found in: [Transfer Learning] Domain Adaptation.
The overall network structure is as follows: it can be regarded as a classification network (such as ResNet-50) plus an fc_adapt module. The loss function is the usual cross-entropy loss followed by a loss that measures the distance between the source and target domains, typically MMD, with a hyperparameter controlling the weight of this distance term.
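In symbols, with $\lambda$ as the trade-off hyperparameter, the objective is simply (a restatement of the description above, matching the training code shown later):

$$\mathcal{L} = \mathcal{L}_{\mathrm{CE}}(x_s, y_s) + \lambda \cdot d\big(f(x_s), f(x_t)\big)$$

where $d(\cdot,\cdot)$ is the chosen domain distance (e.g., MMD).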
II. Code Analysis
1. Data loader
In the main function, the code related to the data loader is:
source_loader, target_train_loader, target_test_loader, n_class = load_data(args)
Jumping into the load_data(.) function, you can see that the data are split into the following three loaders, corresponding to the source domain, the target-domain training set, and the target-domain test set:
source_loader,target_train_loader,target_test_loader
The actual work happens in load_data inside data_loader, which builds the dataset (via torchvision's ImageFolder) and wraps it in a DataLoader:
data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test'])
data_loader = get_data_loader(data, batch_size=batch_size,
                              shuffle=True if train else False,
                              num_workers=num_workers, **kwargs,
                              drop_last=True if train else False)
n_class = len(data.classes)
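The same pattern can be reproduced in a few self-contained lines. A minimal sketch, assuming an ImageFolder-style directory (the transform and path below are placeholders, not the repo's exact settings):

import torch
from torchvision import datasets, transforms

# Placeholder transform; the repo uses separate 'train'/'test' pipelines.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
data = datasets.ImageFolder(root='./data/source', transform=transform)
loader = torch.utils.data.DataLoader(
    data, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
n_class = len(data.classes)  # ImageFolder infers classes from subfolder names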
2. Model
In the main function, the code related to the model is:
model = get_model(args)
def get_model(args):
    model = models.TransferNet(
        args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone,
        max_iter=args.max_iter, use_bottleneck=args.use_bottleneck).to(args.device)
    return model
This function instantiates the TransferNet class from the models file. The class has the following parameters: the number of classes, the backbone type, the transfer_loss type, and a few parameters for configuring the bottleneck on top of the backbone.
class TransferNet(nn.Module):
    def __init__(self, num_class, base_net='resnet50',
                 transfer_loss='mmd', use_bottleneck=True,
                 bottleneck_width=256, max_iter=1000, **kwargs):
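For orientation, instantiating the class directly might look like this (the argument values are illustrative, not the repo's defaults for any particular task):

model = TransferNet(
    num_class=31,            # e.g. Office-31 has 31 categories
    base_net='resnet50',     # backbone for feature extraction
    transfer_loss='mmd',     # which domain-distance loss to use
    use_bottleneck=True,     # insert a bottleneck fc layer before the classifier
    bottleneck_width=256,
    max_iter=1000,
)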
Then look at the forward function of the network, which can be broken down into the following parts:
① Backbone feature extraction
source = self.base_network(source)
target = self.base_network(target)
② Source-domain classification
source_clf = self.classifier_layer(source)
The classifier in the code is a single fully connected layer; its parameters are the feature dimension and the number of output classes:
self.classifier_layer = nn.Linear(feature_dim, num_class)
③ Source-domain classification loss
clf_loss = self.criterion(source_clf, source_label)
A cross-entropy loss function is used in the code:
self.criterion = torch.nn.CrossEntropyLoss()
④ Transfer loss
This part is the biggest difference between domain-adaptive transfer learning and a conventional classification network: in addition to the usual clf_loss, the network also computes a transfer_loss.
For loss types that need extra inputs, the code has dedicated branches for lmmd, daan, and bnm; they follow essentially the same pattern, so lmmd is selected for analysis. The lmmd branch is as follows:
if self.transfer_loss == "lmmd":
    kwargs['source_label'] = source_label
    target_clf = self.classifier_layer(target)
    kwargs['target_logits'] = torch.nn.functional.softmax(target_clf, dim=1)
This branch stores source_label in the keyword arguments and applies the same classifier as above to the target-domain features extracted by the backbone; softmax turns the outputs into class probabilities, recorded as target_logits. The source features, target features, and these keyword arguments are then passed to the adapt_loss(.) module to compute the transfer_loss.
As the code below shows, the simplest mmd loss does not need this step.
transfer_loss = self.adapt_loss(source, target, **kwargs)
self.adapt_loss = TransferLoss(**transfer_loss_args)
The TransferLoss class is a dispatcher that selects among six different loss functions. Here we take mmd as an example:
if loss_type == "mmd":
    self.loss_func = MMDLoss(**kwargs)
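For context, the dispatcher itself is roughly this shape (a sketch showing only the mmd branch; the other branches, e.g. lmmd, daan, bnm, follow the same pattern):

class TransferLoss(nn.Module):
    def __init__(self, loss_type, **kwargs):
        super(TransferLoss, self).__init__()
        self.loss_type = loss_type
        if loss_type == "mmd":
            self.loss_func = MMDLoss(**kwargs)
        # elif loss_type == "lmmd": ... (and so on for the remaining types)

    def forward(self, source, target, **kwargs):
        # Delegate to whichever loss was selected at construction time
        return self.loss_func(source, target, **kwargs)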
The complete code of MMDLoss is as follows:
import torch
import torch.nn as nn

class MMDLoss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num    # number of RBF kernels to combine
        self.kernel_mul = kernel_mul    # bandwidth multiplier between kernels
        self.fix_sigma = fix_sigma      # optional fixed bandwidth; None means estimate from data
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Pairwise squared L2 distances between all (source + target) samples
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        # Bandwidth: either user-fixed, or estimated as the mean pairwise distance
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i)
                          for i in range(kernel_num)]
        # Multi-kernel trick: sum RBF kernels at several bandwidths
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        # Linear-kernel MMD: squared distance between the two feature means
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta)
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul,
                kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # Block means of the joint kernel matrix
            XX = torch.mean(kernels[:batch_size, :batch_size])  # source-source
            YY = torch.mean(kernels[batch_size:, batch_size:])  # target-target
            XY = torch.mean(kernels[:batch_size, batch_size:])  # source-target
            YX = torch.mean(kernels[batch_size:, :batch_size])  # target-source
            loss = XX + YY - XY - YX  # empirical estimate of MMD^2
            return loss
There are quite a few mathematical transformations inside, which we will not derive here. The role of the forward function is to measure the gap between the source and target domains and return it as the loss; with the RBF kernel this is the standard empirical MMD² estimate mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts).
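A quick smoke test of the module (the shapes here are illustrative):

source_feat = torch.randn(32, 256)  # one batch of source features
target_feat = torch.randn(32, 256)  # one batch of target features
mmd = MMDLoss(kernel_type='rbf')
print(mmd(source_feat, target_feat))  # a scalar tensor: the MMD estimate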
⑤ Return values
The model's forward function ultimately returns two values, clf_loss and transfer_loss, which are used for the subsequent backward pass.
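Putting ①-⑤ together, forward is roughly the following (a condensed sketch assembled from the snippets above; the bottleneck layer and loss-specific handling are elided):

def forward(self, source, target, source_label):
    source = self.base_network(source)                   # ① backbone features
    target = self.base_network(target)
    source_clf = self.classifier_layer(source)           # ② source classification
    clf_loss = self.criterion(source_clf, source_label)  # ③ classification loss
    kwargs = {}  # lmmd/daan/bnm would add extra entries here, as shown above
    transfer_loss = self.adapt_loss(source, target, **kwargs)  # ④ domain distance
    return clf_loss, transfer_loss                       # ⑤ both losses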
3. Training
During training, source-domain images and labels are drawn from source_loader, and target-domain images are drawn from target_train_loader (target-domain labels are not needed).
These three inputs are fed into the model to obtain clf_loss and transfer_loss. The final loss is then clf_loss plus transfer_loss multiplied by the weight coefficient lambda:
clf_loss, transfer_loss = model(data_source, data_target, label_source)
loss = clf_loss + args.transfer_loss_weight * transfer_loss
Then comes the standard backward pass:
optimizer.zero_grad()
loss.backward()
optimizer.step()
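One epoch of training therefore looks roughly like this (a simplified sketch; the repo's actual loop also handles iterator lengths, learning-rate scheduling, and logging):

model.train()
for (data_source, label_source), (data_target, _) in zip(source_loader, target_train_loader):
    data_source = data_source.to(args.device)
    label_source = label_source.to(args.device)
    data_target = data_target.to(args.device)  # target labels are never used
    clf_loss, transfer_loss = model(data_source, data_target, label_source)
    loss = clf_loss + args.transfer_loss_weight * transfer_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()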
4. Testing
The testing process is similar to training; the main difference is that no losses need to be computed, so the model's predict function is used instead of the default forward function:
s_output = model.predict(data)
Compared with the forward function above, predict skips the adapt_loss module entirely:
def predict(self, x):
    features = self.base_network(x)
    x = self.bottleneck_layer(features)
    clf = self.classifier_layer(x)
    return clf
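Accuracy on the target-domain test set can then be measured with a loop like this (a minimal sketch using the names from the snippets above):

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, label in target_test_loader:
        data, label = data.to(args.device), label.to(args.device)
        s_output = model.predict(data)
        pred = s_output.argmax(dim=1)           # predicted class per sample
        correct += (pred == label).sum().item()
        total += label.size(0)
print(f'target test accuracy: {correct / total:.4f}')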