Foreword
In the previous article, [Deep Domain Adaptation] 1. Detailed Explanation of DANN and Gradient Reversal Layer (GRL), we explained the basic principles of the DANN network architecture and the Gradient Reversal Layer (GRL). In this article, we reproduce the MNIST to MNIST-M transfer training experiment from the DANN paper, Unsupervised Domain Adaptation by Backpropagation.
1. Introduction to MNIST and MNIST-M
To reproduce the transfer training from MNIST to MNIST-M with DANN, we first need both datasets. MNIST is easy to obtain and can be downloaded from the official website: MNIST. The files to download are the four files marked in blue in the figure below.
Because Keras is integrated into TensorFlow, MNIST can also be loaded directly through the Keras API, as follows:
from tensorflow.keras.datasets import mnist
# Load the MNIST dataset
(X_train,y_train),(X_test,y_test) = mnist.load_data()
The MNIST-M dataset consists of MNIST digits blended with randomly cropped color patches from the BSDS500 dataset. To generate MNIST-M, first download the BSDS500 dataset from its official website: BSDS500.
The following is a screenshot of the official website of the BSDS500 dataset. Click the link in the blue box in the figure below to download the data.
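To make the blending step concrete, here is a minimal sketch of how one digit can be combined with a background patch by absolute difference. It uses random arrays as stand-ins for a real MNIST digit and a real BSDS500 image, and the helper name blend_digit is hypothetical; it only illustrates the idea and is not the project's script.

```python
import numpy as np

def blend_digit(digit, background, rng):
    """Blend a 28x28 grayscale digit into a random crop of an RGB background."""
    # Binarize the digit and replicate it over three channels (values 0 or 255).
    mask = (digit > 0).astype(np.float32)[..., None] * 255.0
    digit_rgb = np.repeat(mask, 3, axis=2)
    # Pick a random 28x28 crop of the background.
    h, w, _ = background.shape
    top = rng.randint(0, h - 28)
    left = rng.randint(0, w - 28)
    patch = background[top:top + 28, left:left + 28].astype(np.float32)
    # Absolute difference: digit pixels invert the patch, other pixels keep it.
    return np.abs(patch - digit_rgb).astype(np.uint8)

rng = np.random.RandomState(0)
digit = np.zeros((28, 28), dtype=np.uint8)
digit[10:18, 12:16] = 200  # a fake "stroke" standing in for a real digit
background = rng.randint(0, 256, size=(64, 96, 3)).astype(np.uint8)
mnistm_like = blend_digit(digit, background, rng)
print(mnistm_like.shape, mnistm_like.dtype)
```

Where the digit is zero the background patch is kept unchanged, and where the digit is set the patch colors are inverted, which is what gives MNIST-M its textured look.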
After downloading BSDS500, we generate the MNIST-M dataset from the MNIST and BSDS500 datasets. The generation script, create_mnistm.py, is as follows:
# -*- coding: utf-8 -*-
# @Time : 2021/7/24 1:50 PM
# @Author : Dai Pu wei
# @Email : [email protected]
# @File : create_mnistm.py
# @Software: PyCharm
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
import numpy as np
import pickle as pkl
import skimage.io
import skimage.transform
from tensorflow.keras.datasets import mnist
rand = np.random.RandomState(42)
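def compose_image(mnist_data, background_data):
    """
    Blend an MNIST digit with a BSDS500 image to form an MNIST-M sample.
    :param mnist_data: MNIST digit as a 28x28x3 array
    :param background_data: BSDS500 image used as the background
    :return: blended 28x28x3 uint8 image
    """
    # Crop a random patch of the background with the same size as the digit.
    # The seeded generator `rand` is used so that the dataset is reproducible.
    w, h, _ = background_data.shape
    dw, dh, _ = mnist_data.shape
    x = rand.randint(0, w - dw)
    y = rand.randint(0, h - dh)
    bg = background_data[x:x + dw, y:y + dh]
    return np.abs(bg - mnist_data).astype(np.uint8)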
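
def mnist_to_img(x):
    """
    Convert an MNIST image to a binarized three-channel image.
    :param x: MNIST image (28x28); every non-zero pixel is treated as foreground
    :return: 28x28x3 float32 array with values 0 or 255
    """
    x = (x > 0).astype(np.float32)
    d = x.reshape([28, 28, 1]) * 255
    return np.concatenate([d, d, d], 2)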
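
def create_mnistm(X, background_data):
    """
    Generate the MNIST-M dataset. For a description of MNIST-M, see:
    http://jmlr.org/papers/volume17/15-239/15-239.pdf
    :param X: MNIST images
    :param background_data: list of BSDS500 images used as backgrounds
    :return: MNIST-M images as an N x 28 x 28 x 3 uint8 array
    """
    # Iterate over all MNIST images and generate the MNIST-M images
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):
        if i % 1000 == 0:
            print('Processing example', i)
        # Pick a random background image by index (the images differ in shape,
        # so sampling the list directly with rand.choice can fail)
        bg_img = background_data[rand.randint(0, len(background_data))]
        # Convert the MNIST digit to a binarized RGB image
        mnist_image = mnist_to_img(X[i])
        # Blend the digit with the BSDS500 background
        mnist_image = compose_image(mnist_image, bg_img)
        X_[i] = mnist_image
    return X_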

def run_main():
    """
    Main function.
    """
    # Initialize paths
    BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz')
    mnist_dir = os.path.abspath("model_data/dataset/MNIST")
    mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M")
    # Load the MNIST dataset
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Load the BSDS500 training images from the archive
    f = tarfile.open(BST_PATH)
    train_files = []
    for name in f.getnames():
        if name.startswith('BSR/BSDS500/data/images/train/'):
            train_files.append(name)
    print('Loading BSR training images')
    background_data = []
    for name in train_files:
        try:
            fp = f.extractfile(name)
            bg_img = skimage.io.imread(fp)
            background_data.append(bg_img)
        except Exception:
            # Skip archive entries that cannot be read as images
            continue
    f.close()
    # Generate the MNIST-M training and validation sets
    print('Building train set...')
    train = create_mnistm(X_train, background_data)
    print(np.shape(train))
    print('Building validation set...')
    valid = create_mnistm(X_test, background_data)
    print(np.shape(valid))
    # Convert the MNIST images to three-channel (RGB) format
    X_train = np.expand_dims(X_train, -1)
    X_test = np.expand_dims(X_test, -1)
    X_train = np.concatenate([X_train, X_train, X_train], axis=3)
    X_test = np.concatenate([X_test, X_test, X_test], axis=3)
    y_train = np.array(y_train).astype(np.int32)
    y_test = np.array(y_test).astype(np.int32)
    # Save the MNIST dataset as a pkl file
    if not os.path.exists(mnist_dir):
        os.makedirs(mnist_dir)
    with open(os.path.join(mnist_dir, 'mnist_data.pkl'), 'wb') as f:
        pkl.dump({
            'train': X_train,
            'train_label': y_train,
            'val': X_test,
            'val_label': y_test}, f, pkl.HIGHEST_PROTOCOL)
    # Save the MNIST-M dataset as a pkl file
    if not os.path.exists(mnistm_dir):
        os.makedirs(mnistm_dir)
    with open(os.path.join(mnistm_dir, 'mnist_m_data.pkl'), 'wb') as f:
        pkl.dump({
            'train': train,
            'train_label': y_train,
            'val': valid,
            'val_label': y_test}, f, pkl.HIGHEST_PROTOCOL)
    # Compute the per-channel pixel mean over all data, used for normalization
    print(np.shape(X_train))
    print(np.shape(X_test))
    print(np.shape(train))
    print(np.shape(valid))
    print(np.shape(y_train))
    print(np.shape(y_test))
    pixel_mean = np.vstack([X_train, train, X_test, valid]).mean((0, 1, 2))
    print(np.shape(pixel_mean))
    print(pixel_mean)

if __name__ == '__main__':
    run_main()
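The last step of the script computes one mean value per color channel over all images. The sketch below shows the same computation on small random stand-in arrays and one way the result could be used for normalization. The division by 255 is an assumption made for illustration, since the normalization used during training is not shown in this article.

```python
import numpy as np

rng = np.random.RandomState(0)
# Stand-ins for the arrays produced by the script (the real ones are N x 28 x 28 x 3).
mnist_like = rng.randint(0, 256, size=(8, 28, 28, 3)).astype(np.uint8)
mnistm_like = rng.randint(0, 256, size=(8, 28, 28, 3)).astype(np.uint8)

# Stack along the sample axis, then average over samples, rows and columns,
# which leaves one mean value per channel.
pixel_mean = np.vstack([mnist_like, mnistm_like]).mean((0, 1, 2))

# Subtract the per-channel mean; the scaling by 255 is an assumption for this sketch.
normalized = (mnist_like.astype(np.float32) - pixel_mean) / 255.0
print(pixel_mean.shape, normalized.shape)
```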
2. Parameter configuration class config
Training the DANN-MNIST network involves many hyperparameters. To keep the project easy to program, we collect all of them in a single class, the parameter configuration class config. Its code is as follows:
# -*- coding: utf-8 -*-
# @Time : 2020/2/15 15:05
# @Author : Dai PuWei
# @Email : [email protected]
# @File : config.py
# @Software: PyCharm
import os
class config(object):
    __default_dict__ = {
        "pre_model_path": None,
        "checkpoints_dir": os.path.abspath("./checkpoints"),
        "logs_dir": os.path.abspath("./logs"),
        "config_dir": os.path.abspath("./config"),
        "image_input_shape": (28, 28, 3),
        "image_size": 28,
        "init_learning_rate": 1e-2,
        "momentum_rate": 0.9,
        "batch_size": 256,
        "epoch": 500,
        "pixel_mean": [45.652287, 45.652287, 45.652287],
    }

    def __init__(self, **kwargs):
        """
        Initialize the parameter configuration class.
        :param kwargs: dictionary of parameters
        """
        # Start from the default parameters
        self.__dict__.update(self.__default_dict__)
        # Override them with the parameters passed in
        self.__dict__.update(kwargs)
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        if not os.path.exists(self.config_dir):
            os.makedirs(self.config_dir)

    def set(self, **kwargs):
        """
        Update configuration parameters.
        :param kwargs: dictionary of parameters
        :return:
        """
        # Override parameters with the values passed in
        self.__dict__.update(kwargs)

    def save_config(self, time):
        """
        Save the configuration to a text file.
        :param time: timestamp string
        :return:
        """
        # Append the timestamp to the output directories
        self.checkpoints_dir = os.path.join(self.checkpoints_dir, time)
        self.logs_dir = os.path.join(self.logs_dir, time)
        self.config_dir = os.path.join(self.config_dir, time)
        if not os.path.exists(self.config_dir):
            os.makedirs(self.config_dir)
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        config_txt_path = os.path.join(self.config_dir, "config.txt")
        with open(config_txt_path, 'a') as f:
            for key, value in self.__dict__.items():
                # The directory attributes already include the timestamp here;
                # str() handles non-string values such as tuples, floats and None
                s = key + ": " + str(value) + "\n"
                f.write(s)
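The configuration class relies on a simple pattern: copy the defaults into the instance dictionary, then let keyword arguments override them. The following sketch isolates that pattern in a reduced, hypothetical class named MiniConfig, which keeps only a few of the parameters and creates no directories, so it can be run anywhere.

```python
class MiniConfig(object):
    """Hypothetical, reduced stand-in for the config class (no directory creation)."""
    _defaults = {
        "init_learning_rate": 1e-2,
        "momentum_rate": 0.9,
        "batch_size": 256,
        "epoch": 500,
    }

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)  # start from the defaults
        self.__dict__.update(kwargs)          # keyword arguments override them

    def set(self, **kwargs):
        self.__dict__.update(kwargs)          # later updates work the same way

cfg = MiniConfig(batch_size=128)
cfg.set(epoch=100)
print(cfg.batch_size, cfg.epoch, cfg.init_learning_rate)  # 128 100 0.01
```

Because the defaults are copied into each instance, overriding a value on one configuration object does not affect the defaults seen by another.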
3. GradientReversalLayer
A key module in DANN is the Gradient Reversal Layer (GRL). Its TensorFlow 2.x implementation is as follows:
import tensorflow as tf
from tensorflow.keras.layers import Layer
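@tf.custom_gradient
def gradient_reversal(x, alpha=1.0):
    # Forward pass: identity. Backward pass: multiply the gradient by -alpha.
    def grad(dy):
        # No gradient is propagated to alpha
        return -dy * alpha, None
    return x, grad

class GradientReversalLayer(Layer):
    def __init__(self, **kwargs):
        """
        Initialize the gradient reversal layer.
        :param kwargs: keyword arguments passed to the base Layer
        """
        super(GradientReversalLayer, self).__init__(**kwargs)

    def call(self, x, alpha=1.0):
        """
        Forward pass of the gradient reversal layer.
        :param x: input tensor
        :param alpha: alpha coefficient, 1 by default
        :return: x unchanged; its gradient is reversed in the backward pass
        """
        return gradient_reversal(x, alpha)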
In the code above, the @tf.custom_gradient decorator lets the gradient_reversal function define its own gradient. In the forward pass the function returns x unchanged. In the backward pass the inner grad function multiplies the incoming gradient dy by -alpha, so the gradient flowing back through this layer is reversed and scaled. The second return value of grad is None because no gradient is needed for alpha. In the earlier TF1.x version, the same effect was obtained by registering a gradient function with @ops.RegisterGradient and substituting it through gradient_override_map, whose argument is a dictionary: the gradient function named by each value is used in place of the one named by the corresponding key.
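The effect of the layer on backpropagation can be checked by hand on a toy example. The sketch below does not use TensorFlow. It applies the chain rule manually to a one-parameter "feature extractor" and a squared-error "domain loss", both of which are made up for illustration, and shows that the gradient reaching the weight has its sign flipped and is scaled by alpha.

```python
def grl_forward(x):
    # Identity in the forward pass
    return x

def grl_backward(dy, alpha=1.0):
    # Reversed and scaled gradient in the backward pass
    return -alpha * dy

# Toy chain (hypothetical): feature f = w * x, domain loss L = (GRL(f) - t) ** 2
w, x, t, alpha = 0.5, 2.0, 3.0, 0.7
f = w * x
out = grl_forward(f)

dL_dout = 2.0 * (out - t)             # gradient arriving at the GRL from the loss
dL_df = grl_backward(dL_dout, alpha)  # gradient leaving the GRL toward the features
dL_dw = dL_df * x                     # gradient used to update the feature weight

plain_dL_dw = dL_dout * x             # what the weight would receive without the GRL
print(dL_dw, plain_dL_dw)
```

A gradient descent step on w therefore increases the domain loss instead of decreasing it, which is how the feature extractor is pushed toward domain-invariant features while the domain classifier itself is still trained normally.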
4. DANN class code
The DANN paper, Unsupervised Domain Adaptation by Backpropagation, specifies the network used for the MNIST to MNIST-M transfer training experiment. Its architecture is shown in the figure below.
Next, we use TensorFlow 2.4.0 to build the DANN-MNIST network. The code for the three subnetworks is as follows:
# -*- coding: utf-8 -*-
# @Time : 2020/2/14 20:27
# @Author : Dai PuWei
# @Email : [email protected]
# @File : MNIST2MNIST_M.py
# @Software: PyCharm
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Activation
def build_feature_extractor():
    """
    Build the feature extraction subnetwork.
    :return: Keras Sequential model mapping images to flattened features
    """
    model = tf.keras.Sequential([Conv2D(filters=32, kernel_size=5, strides=1),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 MaxPool2D(pool_size=(2, 2), strides=2),
                                 Conv2D(filters=48, kernel_size=5, strides=1),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 MaxPool2D(pool_size=(2, 2), strides=2),
                                 Flatten(),
                                 ])
    return model

def build_image_classify_extractor():
    """
    Build the image (label) classifier.
    :return: Keras Sequential model mapping features to 10 class probabilities
    """
    model = tf.keras.Sequential([Dense(100),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(100, activation='relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(10, activation='softmax', name="image_cls_pred"),
                                 ])
    return model

def build_domain_classify_extractor():
    """
    Build the domain classifier.
    :return: Keras Sequential model mapping features to 2 domain probabilities
    """
    # Build the domain classifier
    model = tf.keras.Sequential([Dense(100),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(2, activation='softmax', name="domain_cls_pred")
                                 ])
    return model
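As a sanity check on the feature extractor, the size of its flattened output can be worked out by hand. The sketch below assumes the Keras defaults for Conv2D and MaxPool2D (padding='valid', which the code above does not override) and a 28x28 input as set in the configuration class. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1):
    # 'valid' padding: no zero padding is added around the input
    return (size - kernel) // stride + 1

def pool_out(size, pool=2, stride=2):
    return (size - pool) // stride + 1

size = 28                  # input height and width
size = conv_out(size, 5)   # first 5x5 convolution, 32 filters -> 24
size = pool_out(size)      # first 2x2 max pooling             -> 12
size = conv_out(size, 5)   # second 5x5 convolution, 48 filters -> 8
size = pool_out(size)      # second 2x2 max pooling             -> 4
flat_features = size * size * 48
print(size, flat_features)  # 4 768
```

Under these assumptions, both the image classifier and the domain classifier receive a 768-dimensional feature vector per image.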
6. Experimental results
This section visualizes the following quantities recorded during the adaptive training on the MNIST and MNIST-M datasets: the learning rate, the gradient reversal layer parameter λ, and, for both the training and validation sets, the image classification loss, domain classification loss, image classification accuracy, domain classification accuracy, and total model loss.
First, the learning rate and the gradient reversal layer parameter λ during training.
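For reference, the DANN paper schedules both quantities as functions of the training progress p, which runs from 0 to 1: λ rises from 0 toward 1, and the learning rate decays from its initial value. The sketch below implements the formulas and constants given in the paper. That this project follows exactly the same schedule is an assumption, although the initial learning rate of 0.01 matches init_learning_rate in the configuration class.

```python
import math

def grl_lambda(p, gamma=10.0):
    # lambda_p = 2 / (1 + exp(-gamma * p)) - 1, with gamma = 10 in the paper
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

def learning_rate(p, mu0=0.01, alpha=10.0, beta=0.75):
    # mu_p = mu0 / (1 + alpha * p) ** beta, with mu0 = 0.01, alpha = 10, beta = 0.75
    return mu0 / (1.0 + alpha * p) ** beta

# p is the fraction of training completed, e.g. current_step / total_steps
for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(p, round(grl_lambda(p), 4), round(learning_rate(p), 5))
```

Starting λ at 0 keeps the noisy early domain classifier from disturbing the feature extractor, and the domain-adversarial signal is then phased in as training proceeds.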
Next are the image classification accuracy and domain classification accuracy on the training and validation sets during training, where blue represents the training set and red represents the validation set. The training accuracy is measured on the source domain dataset, MNIST, and the validation accuracy is measured on the target domain dataset, MNIST-M. The accuracy of the adaptive training on MNIST and MNIST-M stabilizes at about 86%, which is higher than the 81.49% cited here for the original paper. We attribute the difference to the RTX 30-series graphics card used for training and consider the result reasonable.
Finally, the image classification loss and domain classification loss on the training and validation sets during training, where blue represents the training set and red represents the validation set.
Postscript
We initially implemented DANN with the TF1.x framework, but later found that, because of the way GRL has to be implemented there, the TF1.x version adapts poorly to complex network structures such as YOLO v3. The code has therefore been fully upgraded to TF2.x, and PyTorch support will be added if needed. The original TF1.x project code is in the tf1 branch of DANN-MNIST, and the TF2.x project code is available at:
- DANN-MNIST's tf2 and master branches (tf2 and master branches merged)
- DANN-MNIST-tf2
If you find this project useful, you are welcome to like, bookmark, and follow it on CSDN and GitHub.