[Deep Domain Adaptation] 2. Using DANN for Transfer Training from MNIST to MNIST-M

Foreword

In the previous article, [Deep Domain Adaptation] 1. Detailed Explanation of DANN and Gradient Reversal Layer (GRL), we explained the basic principles of the DANN network architecture and of the Gradient Reversal Layer (GRL). In this article, we reproduce the MNIST → MNIST-M transfer training experiment from the DANN paper, Unsupervised Domain Adaptation by Backpropagation.


1. Introduction to MNIST and MNIST-M

To use DANN for transfer training between MNIST and MNIST-M, we first need to obtain both datasets. MNIST is easy to obtain, and its official download page is MNIST. The four files to download are the ones highlighted in blue in the figure below.
[Figure: MNIST download page; the four files to download are highlighted in blue]
Because Keras is integrated into TensorFlow, the MNIST dataset can be loaded directly with the Keras datasets API:

from tensorflow.keras.datasets import mnist

# Load the MNIST dataset
(X_train,y_train),(X_test,y_test) = mnist.load_data()

The MNIST-M dataset is built by blending MNIST digits with random color patches cropped from the BSDS500 dataset. To generate MNIST-M, first download BSDS500 from its official download page: BSDS500.
The screenshot below shows the BSDS500 website. Click the link in the blue box to download the data.
[Figure: BSDS500 website; the download link is marked with a blue box]
After downloading BSDS500, we generate the MNIST-M dataset from the MNIST and BSDS500 datasets with the script create_mnistm.py, shown below:

# -*- coding: utf-8 -*-
# @Time    : 2021/7/24 1:50 PM
# @Author  : Dai Pu wei
# @File    : create_mnistm.py
# @Software: PyCharm

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile
import numpy as np
import pickle as pkl
import skimage.io
import skimage.transform
from tensorflow.keras.datasets import mnist

rand = np.random.RandomState(42)

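def compose_image(mnist_data, background_data):
    """
    Blend one MNIST digit with a random patch of a BSDS500 image to form one MNIST-M image
    :param mnist_data: MNIST digit as a 28x28x3 RGB array (values 0 or 255)
    :param background_data: BSDS500 image used as the background
    :return: the composed 28x28x3 uint8 image
    """
    # Randomly crop a patch of the digit's size from the background image,
    # using the seeded RandomState so the generated dataset is reproducible
    w, h, _ = background_data.shape
    dw, dh, _ = mnist_data.shape
    x = rand.randint(0, w - dw)
    y = rand.randint(0, h - dh)
    bg = background_data[x:x + dw, y:y + dh]
    # |background - digit|: background colors are kept where the digit is 0
    # and inverted (255 - background) under the digit strokes
    return np.abs(bg - mnist_data).astype(np.uint8)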

def mnist_to_img(x):
    """
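    Binarize an MNIST digit and convert it to RGB format
    :param x: 28x28 MNIST digit; every non-zero pixel is treated as foreground (1)
    :return: 28x28x3 float32 array whose pixels are 0 or 255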
    """
    x = (x > 0).astype(np.float32)
    d = x.reshape([28, 28, 1]) * 255
    return np.concatenate([d, d, d], 2)

def create_mnistm(X,background_data):
    """
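    Generate the MNIST-M dataset; for an introduction to MNIST-M see:
    http://jmlr.org/papers/volume17/15-239/15-239.pdf
    :param X: MNIST images
    :param background_data: list of BSDS500 images used as backgrounds
    :return: uint8 array of MNIST-M images with shape (N, 28, 28, 3)
    """
    # Iterate over all MNIST images to generate the MNIST-M dataset
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):
        if i % 1000 == 0:
            print('Processing example', i)
        # Randomly pick a background image. Indexing with randint avoids
        # rand.choice(background_data), which would first have to turn the list of
        # differently shaped images into one array (rejected by recent NumPy versions)
        bg_img = background_data[rand.randint(len(background_data))]
        # Convert the MNIST digit into a binarized RGB image
        mnist_image = mnist_to_img(X[i])
        # Blend the MNIST digit with the BSDS500 background
        mnist_image = compose_image(mnist_image, bg_img)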
        X_[i] = mnist_image
    return X_

def run_main():
    """
    Main function
    """
    # Initialize paths
    BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz')
    mnist_dir = os.path.abspath("model_data/dataset/MNIST")
    mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M")

    # Load the MNIST dataset
    (X_train,y_train),(X_test,y_test) = mnist.load_data()


    # Load the BSDS500 training images
    print('Loading BSR training images')
    background_data = []
    with tarfile.open(BST_PATH) as tar:
        for name in tar.getnames():
            if not name.startswith('BSR/BSDS500/data/images/train/'):
                continue
            try:
                fp = tar.extractfile(name)
                bg_img = skimage.io.imread(fp)
                background_data.append(bg_img)
            except Exception:
                # Skip members that are not readable images
                # (e.g. the directory entry itself, for which extractfile returns None)
                continue

    # Generate the MNIST-M training and validation sets
    print('Building train set...')
    train = create_mnistm(X_train,background_data)
    print(np.shape(train))
    print('Building validation set...')
    valid = create_mnistm(X_test,background_data)
    print(np.shape(valid))

    # Convert MNIST to 3-channel RGB format to match MNIST-M
    X_train = np.expand_dims(X_train,-1)
    X_test = np.expand_dims(X_test,-1)
    X_train = np.concatenate([X_train,X_train,X_train],axis=3)
    X_test = np.concatenate([X_test,X_test,X_test],axis=3)
    y_train = np.array(y_train).astype(np.int32)
    y_test = np.array(y_test).astype(np.int32)
    # Save the MNIST dataset as a pkl file
    os.makedirs(mnist_dir, exist_ok=True)
    with open(os.path.join(mnist_dir, 'mnist_data.pkl'), 'wb') as f:
        pkl.dump({'train': X_train,
                  'train_label': y_train,
                  'val': X_test,
                  'val_label': y_test}, f, pkl.HIGHEST_PROTOCOL)

    # Save the MNIST-M dataset as a pkl file
    os.makedirs(mnistm_dir, exist_ok=True)
    with open(os.path.join(mnistm_dir, 'mnist_m_data.pkl'), 'wb') as f:
        pkl.dump({'train': train,
                  'train_label': y_train,
                  'val': valid,
                  'val_label': y_test}, f, pkl.HIGHEST_PROTOCOL)

    # Compute the per-channel pixel mean over MNIST and MNIST-M, used for data normalization
    print(np.shape(X_train))
    print(np.shape(X_test))
    print(np.shape(train))
    print(np.shape(valid))
    print(np.shape(y_train))
    print(np.shape(y_test))
    pixel_mean = np.vstack([X_train,train,X_test,valid]).mean((0,1,2))
    print(np.shape(pixel_mean))
    print(pixel_mean)

if __name__ == '__main__':
    run_main()
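To see what compose_image does to each pixel, the sketch below re-creates mnist_to_img and a variant of compose_image that also returns the cropped patch so the result can be checked. It runs on synthetic data: a random 64x64 "background" and a hypothetical square "digit" stand in for BSDS500 and MNIST, so nothing needs to be downloaded. Because the composition is |background - digit|, pixels outside the strokes keep the background color, and pixels under a stroke (value 255) become 255 minus the background, so the background colors are inverted there.

```python
import numpy as np

def mnist_to_img(x):
    """Binarize a 28x28 digit and replicate it into 3 channels (values 0 or 255)."""
    x = (x > 0).astype(np.float32)
    d = x.reshape([28, 28, 1]) * 255
    return np.concatenate([d, d, d], 2)

def compose_image(digit_rgb, background, rng):
    """Crop a random digit-sized patch; return (|patch - digit| as uint8, patch)."""
    h, w, _ = background.shape
    dh, dw, _ = digit_rgb.shape
    top = rng.randint(0, h - dh)
    left = rng.randint(0, w - dw)
    patch = background[top:top + dh, left:left + dw]
    return np.abs(patch - digit_rgb).astype(np.uint8), patch

rng = np.random.RandomState(42)
# Synthetic stand-ins (assumptions): a random color "photo" and a square "digit"
background = rng.randint(0, 256, size=(64, 64, 3)).astype(np.uint8)
digit = np.zeros((28, 28), dtype=np.uint8)
digit[8:20, 8:20] = 200  # any non-zero value counts as foreground

mnistm_like, patch = compose_image(mnist_to_img(digit), background, rng)
stroke = digit > 0
print(mnistm_like.shape, mnistm_like.dtype)
```

The absolute difference works here because the digit is float32, so `patch - digit_rgb` is computed in floating point and cannot wrap around the way uint8 subtraction would.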


2. Parameter configuration class config

Training the DANN-MNIST network involves many hyperparameters. To keep the project easy to program, we take an object-oriented approach and collect all of them in a single parameter configuration class, config. Its code is as follows:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/15 15:05
# @Author  : Dai PuWei
# @File    : config.py
# @Software: PyCharm

import os

class config(object):

    __default_dict__ = {
        "pre_model_path":None,
        "checkpoints_dir":os.path.abspath("./checkpoints"),
        "logs_dir":os.path.abspath("./logs"),
        "config_dir":os.path.abspath("./config"),
        "image_input_shape":(28,28,3),
        "image_size":28,
        "init_learning_rate": 1e-2,
        "momentum_rate":0.9,
        "batch_size":256,
        "epoch":500,
        "pixel_mean":[45.652287,45.652287,45.652287],
    }

    def __init__(self,**kwargs):
        """
        Initializer of the parameter configuration class
        :param kwargs: parameter dictionary; its entries override the defaults
        """
        # Initialize the configuration parameters with the defaults
        self.__dict__.update(self.__default_dict__)
        # Update parameters with the ones passed in
        self.__dict__.update(kwargs)

        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        if not os.path.exists(self.config_dir):
            os.makedirs(self.config_dir)

    def set(self,**kwargs):
        """
        Setter of the parameter configuration
        :param kwargs: parameter dictionary
        :return:
        """
        # Update parameters with the ones passed in
        self.__dict__.update(kwargs)


    def save_config(self,time):
        """
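        Save the parameter configuration to a text file
        :param time: timestamp string used to name this run's subdirectories
        :return:
        """
        # Move the related directories into a per-run subdirectory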
        self.checkpoints_dir = os.path.join(self.checkpoints_dir,time)
        self.logs_dir = os.path.join(self.logs_dir,time)
        self.config_dir = os.path.join(self.config_dir,time)

        if not os.path.exists(self.config_dir):
            os.makedirs(self.config_dir)
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)

        config_txt_path = os.path.join(self.config_dir,"config.txt")
        with open(config_txt_path,'a') as f:
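            for key,value in self.__dict__.items():
                # The directory entries already end with the time subdirectory (set above),
                # so they are written as-is; format() also handles non-string values
                # such as None, numbers, tuples and lists
                f.write("{}: {}\n".format(key, value))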
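The config class only stores values such as pixel_mean and batch_size. The training code that consumes them, together with the pickles written by create_mnistm.py, is not shown in this article. The following is a minimal, hypothetical sketch of how these pieces could fit together. The helper names load_split, normalize and iterate_batches are illustrative, and the normalization (subtract pixel_mean, then divide by 255) is an assumption, not necessarily what the project does. To stay self-contained, the sketch writes a tiny fake pickle with the same keys ('train', 'train_label', 'val', 'val_label') to a temporary directory.

```python
import os
import pickle
import tempfile
import numpy as np

PIXEL_MEAN = np.array([45.652287, 45.652287, 45.652287], dtype=np.float32)  # "pixel_mean" in config
BATCH_SIZE = 256  # "batch_size" in config

def load_split(pkl_path):
    """Load one pickle written by create_mnistm.py (hypothetical helper)."""
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    return data["train"], data["train_label"], data["val"], data["val_label"]

def normalize(images, pixel_mean=PIXEL_MEAN):
    """Assumed scheme: subtract the per-channel mean, then scale by 1/255."""
    return (images.astype(np.float32) - pixel_mean) / 255.0

def iterate_batches(images, labels, batch_size, rng):
    """Yield shuffled (images, labels) mini-batches; the last one may be smaller."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

# Tiny fake dataset with the same layout as mnist_m_data.pkl
rng = np.random.default_rng(0)
fake = {"train": rng.integers(0, 256, size=(600, 28, 28, 3), dtype=np.uint8),
        "train_label": rng.integers(0, 10, size=600).astype(np.int32),
        "val": rng.integers(0, 256, size=(100, 28, 28, 3), dtype=np.uint8),
        "val_label": rng.integers(0, 10, size=100).astype(np.int32)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mnist_m_data.pkl")
    with open(path, "wb") as f:
        pickle.dump(fake, f, pickle.HIGHEST_PROTOCOL)
    x_train, y_train, x_val, y_val = load_split(path)

batches = list(iterate_batches(normalize(x_train), y_train, BATCH_SIZE, rng))
print(len(batches), [b[0].shape[0] for b in batches])
```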


3. GradientReversalLayer

The key module in DANN is the Gradient Reversal Layer (GRL). Its tf2.x implementation is as follows:

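import tensorflow as tf
from tensorflow.keras.layers import Layer

@tf.custom_gradient
def gradient_reversal(x,alpha=1.0):
    # Forward pass: identity. Backward pass: negate the upstream gradient and scale
    # it by alpha; None is returned as the (unused) gradient with respect to alpha.
    # alpha must be passed positionally (as call() does) so that the two returned
    # gradients match the two inputs
    def grad(dy):
        return -dy * alpha, None
    return x, grad

class GradientReversalLayer(Layer):

    def __init__(self,**kwargs):
        """
        Initializer of the gradient reversal layer
        :param kwargs: parameter dictionary
        """
        super(GradientReversalLayer,self).__init__(**kwargs)

    def call(self, x,alpha=1.0):
        """
        Forward pass of the gradient reversal layer
        :param x: input tensor
        :param alpha: gradient scaling coefficient, 1 by default
        :return: x unchanged; its gradient is reversed during backpropagation
        """
        return gradient_reversal(x,alpha)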

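In the code above, the @tf.custom_gradient decorator lets gradient_reversal define its own gradient. The forward pass returns x unchanged, so the GRL is an identity mapping. The inner grad function receives the upstream gradient dy and returns -dy * alpha, i.e. the gradient with its sign flipped and scaled by alpha, plus None because no gradient is needed with respect to alpha. During backpropagation, the feature extractor therefore receives the reversed domain-classification gradient. The earlier tf1.x version of this project achieved the same effect differently: @ops.RegisterGradient(grad_name) registered a _flip_gradients(op, grad) function that inverts the gradient, and gradient_override_map, whose argument is a dictionary, told TensorFlow to compute the gradient of the op type given by each key with the function registered under the corresponding value.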
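A quick way to check that the layer behaves as described is to push a tensor through it under tf.GradientTape. The forward output should equal the input, and the gradient should be the ordinary gradient multiplied by -alpha. The sketch below is an illustrative check, not part of the original project. It repeats the two definitions from above so that it runs on its own, uses a toy loss sum(y^2) (so dL/dy = 2y), and passes alpha positionally because the gradient function returns two values.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer

@tf.custom_gradient
def gradient_reversal(x, alpha=1.0):
    def grad(dy):
        # Reverse (and scale) the upstream gradient; no gradient for alpha
        return -dy * alpha, None
    return x, grad

class GradientReversalLayer(Layer):
    def __init__(self, **kwargs):
        super(GradientReversalLayer, self).__init__(**kwargs)

    def call(self, x, alpha=1.0):
        return gradient_reversal(x, alpha)

x = tf.constant([1.0, 2.0, 3.0])

# Function form with alpha = 0.5 and toy loss L = sum(y^2)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = gradient_reversal(x, 0.5)
    loss = tf.reduce_sum(y * y)
forward = y.numpy()
grad_half = tape.gradient(loss, x).numpy()

# Layer form with the default alpha = 1.0; the name kwarg now reaches Layer.__init__
grl = GradientReversalLayer(name="grl")
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(grl(x) ** 2)
grad_full = tape.gradient(loss, x).numpy()

print(forward)    # identity in the forward pass
print(grad_half)  # -0.5 * 2x
print(grad_full)  # -1.0 * 2x
```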


4. DANN class code

The DANN paper, Unsupervised Domain Adaptation by Backpropagation, gives the network used for the MNIST → MNIST-M transfer experiment. Its architecture is shown in the figure below.
[Figure: DANN network architecture for the MNIST → MNIST-M experiment, from the DANN paper]
Next, we build the DANN-MNIST network with TensorFlow 2.4.0. The code for the network structure is as follows:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/14 20:27
# @Author  : Dai PuWei
# @File    : MNIST2MNIST_M.py
# @Software: PyCharm

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Activation

def build_feature_extractor():
    """
    Build the feature-extractor sub-network: two Conv2D/ReLU/MaxPool2D stages
    followed by Flatten
    :return: a tf.keras.Sequential model
    """
    model = tf.keras.Sequential([Conv2D(filters=32, kernel_size=5,strides=1),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 MaxPool2D(pool_size=(2, 2), strides=2),
                                 Conv2D(filters=48, kernel_size=5,strides=1),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 MaxPool2D(pool_size=(2, 2), strides=2),
                                 Flatten(),
    ])
    return model

def build_image_classify_extractor():
    """
    Build the image (label) classifier that maps features to 10 digit-class probabilities
    :return: a tf.keras.Sequential model
    """
    model = tf.keras.Sequential([Dense(100),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(100,activation='relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(10,activation='softmax',name="image_cls_pred"),
    ])
    return model

def build_domain_classify_extractor():
    """
    Build the domain classifier that maps features to 2 domain probabilities (source/target)
    :return: a tf.keras.Sequential model
    """
    # Build the domain classifier
    model = tf.keras.Sequential([Dense(100),
                                 #tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 #tf.keras.layers.Dropout(0.5),
                                 Dense(2, activation='softmax', name="domain_cls_pred")
    ])
    return model
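The three builder functions above are not yet connected to each other or to the GRL. The sketch below shows one possible way to wire them together and run a single DANN update with tf.GradientTape. It is an illustrative assumption, not the project's DANN class. It rebuilds equivalent sub-networks inline so that it is self-contained, labels source images 0 and target images 1 for the domain classifier (an assumed convention), computes the image-classification loss on the labeled source half only, and uses SGD with the config values init_learning_rate=1e-2 and momentum_rate=0.9. With 28x28 inputs, the two valid 5x5 convolutions and 2x2 poolings give 28 → 24 → 12 → 8 → 4, so the flattened feature vector has 4 × 4 × 48 = 768 elements.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, MaxPool2D

@tf.custom_gradient
def gradient_reversal(x, alpha=1.0):
    def grad(dy):
        return -dy * alpha, None
    return x, grad

# Sub-networks equivalent to build_feature_extractor(), build_image_classify_extractor()
# and build_domain_classify_extractor() above
feature_extractor = tf.keras.Sequential([
    Conv2D(32, 5), Activation("relu"), MaxPool2D(2, 2),
    Conv2D(48, 5), Activation("relu"), MaxPool2D(2, 2),
    Flatten()])
image_classifier = tf.keras.Sequential([
    Dense(100), Activation("relu"), Dense(100, activation="relu"),
    Dense(10, activation="softmax")])
domain_classifier = tf.keras.Sequential([
    Dense(100), Activation("relu"), Dense(2, activation="softmax")])

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy()

def train_step(x_src, y_src, x_tgt, grl_lambda):
    """One hypothetical DANN update on a labeled source batch plus an unlabeled target batch."""
    n_src, n_tgt = len(x_src), len(x_tgt)
    x_all = tf.concat([x_src, x_tgt], axis=0)
    # Assumed domain labels: 0 = source (MNIST), 1 = target (MNIST-M)
    d_all = tf.concat([tf.zeros([n_src], tf.int32), tf.ones([n_tgt], tf.int32)], axis=0)
    with tf.GradientTape() as tape:
        features = feature_extractor(x_all, training=True)
        # The image classifier only sees the labeled source samples
        cls_loss = cross_entropy(y_src, image_classifier(features[:n_src], training=True))
        # The domain classifier sees all samples through the GRL
        dom_pred = domain_classifier(gradient_reversal(features, grl_lambda), training=True)
        dom_loss = cross_entropy(d_all, dom_pred)
        total_loss = cls_loss + dom_loss
    variables = (feature_extractor.trainable_variables
                 + image_classifier.trainable_variables
                 + domain_classifier.trainable_variables)
    grads = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return float(cls_loss), float(dom_loss), int(features.shape[-1])

# Random stand-in batches (assumption) shaped like config's image_input_shape
rng = np.random.default_rng(0)
x_src = rng.random((8, 28, 28, 3), dtype=np.float32)
y_src = rng.integers(0, 10, size=8)
x_tgt = rng.random((8, 28, 28, 3), dtype=np.float32)
cls_loss, dom_loss, feature_dim = train_step(x_src, y_src, x_tgt, grl_lambda=0.5)
print(cls_loss, dom_loss, feature_dim)
```

Because the GRL flips the domain gradient before it reaches the feature extractor, minimizing the single total_loss trains the domain classifier to separate the domains while pushing the features toward domain invariance.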

6. Experimental results

This section visualizes the MNIST → MNIST-M adaptation training: the learning rate and the GRL parameter $\lambda$, as well as the image classification loss, domain classification loss, image classification accuracy, domain classification accuracy and total model loss on the training and validation sets.

First, the learning rate and the GRL parameter $\lambda$ during training:
[Figure: learning rate and GRL parameter $\lambda$ during training]
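The article does not list the schedules behind these curves. For reference, the DANN paper anneals the learning rate as $\mu_p = \mu_0 / (1 + \alpha p)^{\beta}$ and ramps the GRL coefficient as $\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$. Here $p \in [0, 1]$ is the training progress, with $\mu_0 = 0.01$, $\alpha = 10$, $\beta = 0.75$ and $\gamma = 10$ (this $\alpha$ is unrelated to the alpha argument of gradient_reversal). Assuming the project follows these formulas, with $\mu_0$ taken from init_learning_rate in config, they can be computed as below. The function names and the step count are illustrative.

```python
import math

def grl_lambda(progress, gamma=10.0):
    """GRL coefficient lambda_p = 2 / (1 + exp(-gamma * p)) - 1, rising from 0 toward 1."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

def annealed_lr(progress, init_lr=1e-2, alpha=10.0, beta=0.75):
    """Learning rate mu_p = mu_0 / (1 + alpha * p) ** beta, decaying as p goes 0 -> 1."""
    return init_lr / (1.0 + alpha * progress) ** beta

# Hypothetical run: progress p measured in optimizer steps
total_steps = 1000
for step in (0, 100, 500, 1000):
    p = step / total_steps
    print(f"step {step:4d}  p={p:.2f}  lambda={grl_lambda(p):.4f}  lr={annealed_lr(p):.6f}")
```

Starting $\lambda$ at 0 keeps the noisy domain signal from disturbing the feature extractor early in training. The learning rate decays to about 1/6 of its initial value by the end.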

Next are the image classification accuracy and domain classification accuracy on the training and validation sets during training, with blue for the training set and red for the validation set. The training accuracy is measured on the source domain (MNIST), and the validation accuracy on the target domain (MNIST-M). The accuracy of the MNIST → MNIST-M adaptation stabilizes at about 86%, well above the 81.49% reported in the original paper (the experiments here were run on an RTX 30-series GPU).
[Figure: image and domain classification accuracy; blue = training set (MNIST), red = validation set (MNIST-M)]

Finally, the image classification loss and domain classification loss on the training and validation sets during training, again with blue for the training set and red for the validation set.
[Figure: image and domain classification loss; blue = training set, red = validation set]
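When reading the domain-classification curves, it helps to know the chance-level values. Take a balanced batch of source (label 0) and target (label 1) samples. A domain classifier that cannot tell the domains apart outputs about 0.5 for each class, which gives a cross-entropy of ln 2 ≈ 0.693 and an accuracy of about 50%. That is the equilibrium the GRL pushes toward. A domain accuracy near 100% would instead mean the features still separate MNIST from MNIST-M. The NumPy sketch below computes both metrics for a confused and a confident classifier. The helper names and the 0/1 label convention are assumptions.

```python
import numpy as np

def domain_labels(n_source, n_target):
    """Assumed convention: 0 = source (MNIST), 1 = target (MNIST-M)."""
    return np.concatenate([np.zeros(n_source, dtype=np.int64),
                           np.ones(n_target, dtype=np.int64)])

def cross_entropy(probs, labels, eps=1e-7):
    """Mean negative log-likelihood of the true class."""
    p_true = np.clip(probs[np.arange(len(labels)), labels], eps, 1.0)
    return float(-np.mean(np.log(p_true)))

def accuracy(probs, labels):
    return float(np.mean(np.argmax(probs, axis=1) == labels))

labels = domain_labels(128, 128)

# Fully confused domain classifier: 50/50 for every sample
confused = np.full((256, 2), 0.5)
# Confident, correct domain classifier: 0.99 on the true domain
confident = np.where(labels[:, None] == np.arange(2), 0.99, 0.01)

for name, probs in (("confused", confused), ("confident", confident)):
    print(name, round(cross_entropy(probs, labels), 4), accuracy(probs, labels))
```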

Postscript

DANN was first implemented with tf1.x. Because of the special nature of the GRL, however, the tf1.x implementation adapted poorly to complex network structures such as YOLO v3, so the code has been fully upgraded to tf2.x; PyTorch support may be added if needed. The original tf1.x code is on the tf1 branch of DANN-MNIST, and the tf2.x project code is at:


Origin blog.csdn.net/qq_30091945/article/details/104495520