Adversarial Network Learning (Part 10): Raindrop Removal from Images with AttentiveGAN (TensorFlow Implementation)

I. Background

AttentiveGAN is a model proposed by Rui Qian et al. in November 2017 in the paper "Attentive Generative Adversarial Network for Raindrop Removal from A Single Image". It introduces an attention map into the generator network, which clearly improves the removal of raindrops from images.

This experiment is mainly based on the reference code [2], with some simple improvements, so the whole pipeline is implemented in fairly compact code.

[1] Paper: https://arxiv.org/pdf/1711.10098.pdf

[2] attentive-gan-derainnet: https://github.com/MaybeShewill-CV/attentive-gan-derainnet

II. How AttentiveGAN Works

There are not many introductions to AttentiveGAN online; here is one recommendation:

[3] 效果惊艳!北大团队提出Attentive GAN去除图像中雨滴 (a Chinese-language introduction to the paper)

Let's look at a few of the authors' own descriptions. First, they point out that raindrop removal is difficult for two reasons:

The problem is intractable, since first the regions occluded by raindrops are not given. Second, the information about the background scene of the occluded regions is completely lost for most part. (That is: first, the regions covered by raindrops are not given; second, most of the background information in the covered regions is lost.)

The authors then state that the paper's main contribution is injecting an attention map into the model, so that the generator attends to the structures in and around the raindrop regions, while the discriminator can assess the local consistency of the restored regions.

Our main idea is to inject visual attention into both the generative and discriminative networks. During the training, our visual attention learns about raindrop regions and their surroundings. Hence, by injecting this information, the generative network will pay more attention to the raindrop regions and the surrounding structures, and the discriminative network will be able to assess the local consistency of the restored regions. This injection of visual attention to both generative and discriminative networks is the main contribution of this paper.

So how is the attention map generated? The authors build it from ResNet blocks, an LSTM, and a few convolutional layers, and name this structure the attentive-recurrent network. Meanwhile, the input image can be expressed in three parts:

I = (1 - M) ⊙ B + R

That is, the input image I can be viewed as a blend of the background B, kept outside the raindrop mask M, and the raindrop effect R. The information we care about is mainly the background region; the foreground, i.e. the raindrop region, is usually blurred. Our goal is therefore to recover the background B of the input image I, and the raindrop mask M is what the attention map learns to produce.
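
To make the decomposition concrete, here is a tiny NumPy sketch (purely illustrative; the arrays are synthetic and not from the paper's data):

import numpy as np

h, w = 240, 360
background = np.random.rand(h, w, 3)   # clean background B
rain = np.zeros((h, w, 3))             # raindrop effect R
mask = np.zeros((h, w, 1))             # binary raindrop mask M

# paint one synthetic raindrop region
mask[100:120, 150:170] = 1.0
rain[100:120, 150:170, :] = np.random.rand(20, 20, 3)

# I = (1 - M) ⊙ B + R: raindrop pixels replace the background
rainy_image = (1.0 - mask) * background + rain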

Next, the network architecture.

In the generator, the authors first use a recurrent network to generate the attention map; each time step uses a block of 5 ResNet units, a convolutional LSTM, and one convolutional layer. The generator then applies an autoencoder-like structure, the Contextual Autoencoder, which contains 16 conv-relu layers. As for the discriminator, any inconsistency left in the image is a convenient cue for telling real from fake, so the authors adopt a local discriminator: the input is passed through a CNN to extract features, an attention map is derived to guide the discriminator, and fully connected layers finally judge whether the image is real or fake. The discriminator consists of 7 convolutional layers plus one fully connected layer.

Finally, the authors compare their model against other methods, with very strong results (see the comparisons in the paper).

III. Overview of All AttentiveGAN Files

Sections III, IV and V all cover the implementation. I mainly followed the reference code [2] and made only small modifications. The paper's authors also released their original code, but it is a PyTorch version; if interested, see [4]. Below, every file is introduced one by one.

[4] https://github.com/rui1996/DeRaindrop

1. File structure

The file structure is as follows; files you need to prepare yourself are marked with #####:

-- attentive_gan_model                          # AttentiveGAN model files
            |------ attentive_gan_net.py
            |------ cnn_basenet.py
            |------ derain_drop_net.py
            |------ discriminative_net.py
            |------ tf_ssim.py
            |------ vgg16.py
-- data_provider                                # data-loading code
            |------ data_provider.py
-- config                                       # configuration files
            |------ global_config.py
-- data2txt.py                                  # writes the data paths into train.txt
-- train_model.py                               # training script
-- test_model.py                                # test script
-- data
    |------ test_data                           # test data
                |------ 0_rain.png
                |------ 1_rain.png
                |------ ......
    |------ training_data                       # training data
                |------ data                    ##### rainy training images
                        |------ 0_rain.png
                        |------ 1_rain.png
                        |------ ......
                |------ gt                      ##### clean training images
                        |------ 0_clear.png
                        |------ 1_clear.png
                        |------ ......
                |------ train.txt               ##### generated by data2txt.py once the data is in place
    |------ vgg16.npy                           ##### download manually

2. Data preparation

We need to prepare three things here:

(1) The vgg16.npy file

The reference code [2] does not come with a vgg16.npy file, so I found a copy myself and uploaded it to my own Baidu Cloud; of course, if you already have one you can use it directly. The link:

Baidu Cloud: https://pan.baidu.com/s/13lZ1PEVTvpBt5l1-7qZaPQ

Extraction code: hqnr

After downloading, place it under './data/'.

(2) Training and test data

If you don't want to train, you can simply test on a few photos of your own. For the sake of completeness, though, the paper's authors provide a download link for the original data:

https://drive.google.com/open?id=1e7R76s6vwUJxILOcAsthgDLPSnOrQ49K

The training set contains 861 image pairs and the test set 239. The link above may not be directly accessible in some regions (a proxy may be needed); once opened, just download the data.

Since the dataset downloads slowly through Google Drive, I uploaded it to my own Baidu Cloud as well; if Google is inconvenient, use the link below instead:

Baidu Cloud: https://pan.baidu.com/s/1aXEr1Et10SDn5jRT-DISpg

Extraction code: een3

Download and unzip the data, then place it into the corresponding folders following the structure above.

(3) Creating the train.txt file

The last step is creating the txt file. The reference code does not include a script for this, so I wrote a rough one myself: the data2txt.py file, which is described in detail later.

To create it, first put all the data in place, then run data2txt.py directly; it generates train.txt under './data/training_data/'. If every step was done correctly, each line of train.txt is a corresponding image pair, in the following form:
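
For example, with the directory layout above, a typical line looks like this (one rainy-image path, a space, then the matching clean-image path):

./data/training_data/data/0_rain.png ./data/training_data/gt/0_clear.png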

3. Files in the data_provider folder

The data_provider folder contains only one file, data_provider.py; its code follows:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : data_provider.py
import os.path as ops

import numpy as np
import cv2

from config import global_config

CFG = global_config.cfg


class DataSet(object):

    def __init__(self, dataset_info_file):
        self._gt_img_list, self._gt_label_list = self._init_dataset(dataset_info_file)
        self._random_dataset()
        self._next_batch_loop_count = 0

    def _init_dataset(self, dataset_info_file):

        gt_img_list = []
        gt_label_list = []

        assert ops.exists(dataset_info_file), '{:s} does not exist'.format(dataset_info_file)

        with open(dataset_info_file, 'r') as file:
            for _info in file:
                info_tmp = _info.strip(' ').split()

                gt_img_list.append(info_tmp[0])
                gt_label_list.append(info_tmp[1])
                print(gt_img_list[-1], gt_label_list[-1])
        # print(gt_img_list, gt_label_list)
        return gt_img_list, gt_label_list

    def _random_dataset(self):

        assert len(self._gt_img_list) == len(self._gt_label_list)

        random_idx = np.random.permutation(len(self._gt_img_list))
        new_gt_img_list = []
        new_gt_label_list = []

        for index in random_idx:
            new_gt_img_list.append(self._gt_img_list[index])
            new_gt_label_list.append(self._gt_label_list[index])

        self._gt_img_list = new_gt_img_list
        self._gt_label_list = new_gt_label_list


    def next_batch(self, batch_size):
        assert len(self._gt_label_list) == len(self._gt_img_list)

        idx_start = batch_size * self._next_batch_loop_count
        idx_end = batch_size * self._next_batch_loop_count + batch_size

        if idx_end > len(self._gt_label_list):
            self._random_dataset()
            self._next_batch_loop_count = 0
            return self.next_batch(batch_size)
        else:
            gt_img_list = self._gt_img_list[idx_start:idx_end]
            gt_label_list = self._gt_label_list[idx_start:idx_end]

            gt_imgs = []
            gt_labels = []
            mask_labels = []

            for index, gt_img_path in enumerate(gt_img_list):
                gt_image = cv2.imread(gt_img_path, cv2.IMREAD_COLOR)
                label_image = cv2.imread(gt_label_list[index], cv2.IMREAD_COLOR)

                gt_image = cv2.resize(gt_image, (CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT))
                label_image = cv2.resize(label_image, (CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT))

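                # build the binary attention-mask label: pixels whose grayscale
                # difference between the rainy and the clean image is >= 30 are
                # treated as raindrop pixels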
                diff_image = np.abs(np.array(cv2.cvtColor(gt_image, cv2.COLOR_BGR2GRAY), np.float32) -
                                    np.array(cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY), np.float32))

                mask_image = np.zeros(diff_image.shape, np.float32)

                mask_image[np.where(diff_image >= 30)] = 1

                gt_image = np.divide(gt_image, 127.5) - 1
                label_image = np.divide(label_image, 127.5) - 1

                gt_imgs.append(gt_image)
                gt_labels.append(label_image)
                mask_labels.append(mask_image)

            self._next_batch_loop_count += 1
            return gt_imgs, gt_labels, mask_labels
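
A minimal usage sketch of this class (assuming train.txt already exists at the path below):

from data_provider import data_provider

dataset = data_provider.DataSet('./data/training_data/train.txt')
# each call returns three lists of length batch_size:
# rainy inputs, clean labels and binary attention masks
rain_imgs, clean_imgs, masks = dataset.next_batch(batch_size=1)
print(rain_imgs[0].shape, masks[0].shape)   # (240, 360, 3) (240, 360)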

4. Files in the attentive_gan_model folder

The attentive_gan_model folder contains quite a few files. I went through them; the original code is well written and I made almost no changes, only removing the code after the if __name__ == '__main__' blocks. The code of each file follows:

Main code of attentive_gan_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : attentive_gan_net.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet
from attentive_gan_model import vgg16


class GenerativeNet(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(GenerativeNet, self).__init__()
        self._vgg_extractor = vgg16.VGG16Encoder(phase='test')
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)


    def _residual_block(self, input_tensor, name):

        output = None
        with tf.variable_scope(name):
            for i in range(5):
                if i == 0:
                    conv_1 = self.conv2d(inputdata=input_tensor,
                                         out_channel=32,
                                         kernel_size=3,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_1'.format(i))
                    relu_1 = self.lrelu(inputdata=conv_1, name='block_{:d}_relu_1'.format(i))
                    output = relu_1
                    input_tensor = output
                else:
                    conv_1 = self.conv2d(inputdata=input_tensor,
                                         out_channel=32,
                                         kernel_size=1,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_1'.format(i))
                    relu_1 = self.lrelu(inputdata=conv_1, name='block_{:d}_relu_1'.format(i))
                    conv_2 = self.conv2d(inputdata=relu_1,
                                         out_channel=32,
                                         kernel_size=1,
                                         padding='SAME',
                                         stride=1,
                                         use_bias=False,
                                         name='block_{:d}_conv_2'.format(i))
                    relu_2 = self.lrelu(inputdata=conv_2, name='block_{:d}_relu_2'.format(i))

                    output = self.lrelu(inputdata=tf.add(relu_2, input_tensor),
                                        name='block_{:d}_add'.format(i))
                    input_tensor = output

        return output

    def _conv_lstm(self, input_tensor, input_cell_state, name):

        with tf.variable_scope(name):
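            # convolutional LSTM cell: sigmoid input gate i, forget gate f and
            # output gate o, tanh candidate state; a final 3x3 conv plus sigmoid
            # turns the hidden features into a single-channel attention map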
            conv_i = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_i')
            sigmoid_i = self.sigmoid(inputdata=conv_i, name='sigmoid_i')

            conv_f = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_f')
            sigmoid_f = self.sigmoid(inputdata=conv_f, name='sigmoid_f')

            cell_state = sigmoid_f * input_cell_state + \
                         sigmoid_i * tf.nn.tanh(self.conv2d(inputdata=input_tensor,
                                                            out_channel=32,
                                                            kernel_size=3,
                                                            padding='SAME',
                                                            stride=1,
                                                            use_bias=False,
                                                            name='conv_c'))
            conv_o = self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3, padding='SAME',
                                 stride=1, use_bias=False, name='conv_o')
            sigmoid_o = self.sigmoid(inputdata=conv_o, name='sigmoid_o')

            lstm_feats = sigmoid_o * tf.nn.tanh(cell_state)

            attention_map = self.conv2d(inputdata=lstm_feats, out_channel=1, kernel_size=3, padding='SAME',
                                        stride=1, use_bias=False, name='attention_map')
            attention_map = self.sigmoid(inputdata=attention_map)

            ret = {
                'attention_map': attention_map,
                'cell_state': cell_state,
                'lstm_feats': lstm_feats
            }

            return ret

    def build_attentive_rnn(self, input_tensor, name):

        [batch_size, tensor_h, tensor_w, _] = input_tensor.get_shape().as_list()
        with tf.variable_scope(name):
            init_attention_map = tf.constant(0.5, dtype=tf.float32,
                                             shape=[batch_size, tensor_h, tensor_w, 1])
            init_cell_state = tf.constant(0.0, dtype=tf.float32,
                                          shape=[batch_size, tensor_h, tensor_w, 32])
            init_lstm_feats = tf.constant(0.0, dtype=tf.float32,
                                          shape=[batch_size, tensor_h, tensor_w, 32])

            attention_map_list = []

            for i in range(4):
                attention_input = tf.concat((input_tensor, init_attention_map), axis=-1)
                conv_feats = self._residual_block(input_tensor=attention_input,
                                                  name='residual_block_{:d}'.format(i + 1))
                lstm_ret = self._conv_lstm(input_tensor=conv_feats,
                                           input_cell_state=init_cell_state,
                                           name='conv_lstm_block_{:d}'.format(i + 1))
                init_attention_map = lstm_ret['attention_map']
                init_cell_state = lstm_ret['cell_state']
                init_lstm_feats = lstm_ret['lstm_feats']

                attention_map_list.append(lstm_ret['attention_map'])

        ret = {
            'final_attention_map': init_attention_map,
            'final_lstm_feats': init_lstm_feats,
            'attention_map_list': attention_map_list
        }

        return ret

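    # attention loss: each time step's attention map A_t is compared with the
    # mask label M by MSE and weighted by a power of 0.8
    # (tf.pow(0.8, n - index + 1)), so later, more refined attention maps
    # contribute more to the loss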
    def compute_attentive_rnn_loss(self, input_tensor, label_tensor, name):

        with tf.variable_scope(name):
            inference_ret = self.build_attentive_rnn(input_tensor=input_tensor,
                                                     name='attentive_inference')
            loss = tf.constant(0.0, tf.float32)
            n = len(inference_ret['attention_map_list'])
            for index, attention_map in enumerate(inference_ret['attention_map_list']):
                mse_loss = tf.pow(0.8, n - index + 1) * \
                           tf.losses.mean_squared_error(labels=label_tensor,
                                                        predictions=attention_map)
                loss = tf.add(loss, mse_loss)

        return loss, inference_ret['final_attention_map']

    def build_autoencoder(self, input_tensor, name):

        with tf.variable_scope(name):
            conv_1 = self.conv2d(inputdata=input_tensor, out_channel=64, kernel_size=5,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_1')
            relu_1 = self.lrelu(inputdata=conv_1, name='relu_1')

            conv_2 = self.conv2d(inputdata=relu_1, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=2, use_bias=False, name='conv_2')
            relu_2 = self.lrelu(inputdata=conv_2, name='relu_2')

            conv_3 = self.conv2d(inputdata=relu_2, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_3')
            relu_3 = self.lrelu(inputdata=conv_3, name='relu_3')

            conv_4 = self.conv2d(inputdata=relu_3, out_channel=128, kernel_size=3,
                                 padding='SAME',
                                 stride=2, use_bias=False, name='conv_4')
            relu_4 = self.lrelu(inputdata=conv_4, name='relu_4')

            conv_5 = self.conv2d(inputdata=relu_4, out_channel=256, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_5')
            relu_5 = self.lrelu(inputdata=conv_5, name='relu_5')

            conv_6 = self.conv2d(inputdata=relu_5, out_channel=256, kernel_size=3,
                                 padding='SAME',
                                 stride=1, use_bias=False, name='conv_6')
            relu_6 = self.lrelu(inputdata=conv_6, name='relu_6')

            dia_conv1 = self.dilation_conv(input_tensor=relu_6, k_size=3, out_dims=256, rate=2,
                                           padding='SAME', use_bias=False, name='dia_conv_1')
            relu_7 = self.lrelu(dia_conv1, name='relu_7')

            dia_conv2 = self.dilation_conv(input_tensor=relu_7, k_size=3, out_dims=256, rate=4,
                                           padding='SAME', use_bias=False, name='dia_conv_2')
            relu_8 = self.lrelu(dia_conv2, name='relu_8')

            dia_conv3 = self.dilation_conv(input_tensor=relu_8, k_size=3, out_dims=256, rate=8,
                                           padding='SAME', use_bias=False, name='dia_conv_3')
            relu_9 = self.lrelu(dia_conv3, name='relu_9')

            dia_conv4 = self.dilation_conv(input_tensor=relu_9, k_size=3, out_dims=256, rate=16,
                                           padding='SAME', use_bias=False, name='dia_conv_4')
            relu_10 = self.lrelu(dia_conv4, name='relu_10')

            conv_7 = self.conv2d(inputdata=relu_10, out_channel=256, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_7')
            relu_11 = self.lrelu(inputdata=conv_7, name='relu_11')

            conv_8 = self.conv2d(inputdata=relu_11, out_channel=256, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_8')
            relu_12 = self.lrelu(inputdata=conv_8, name='relu_12')

            deconv_1 = self.deconv2d(inputdata=relu_12, out_channel=128, kernel_size=4,
                                     stride=2, padding='SAME', use_bias=False, name='deconv_1')
            avg_pool_1 = self.avgpooling(inputdata=deconv_1, kernel_size=2, stride=1, padding='SAME',
                                         name='avg_pool_1')
            relu_13 = self.lrelu(inputdata=avg_pool_1, name='relu_13')

            conv_9 = self.conv2d(inputdata=tf.add(relu_13, relu_3), out_channel=128, kernel_size=3,
                                 padding='SAME', stride=1, use_bias=False,
                                 name='conv_9')
            relu_14 = self.lrelu(inputdata=conv_9, name='relu_14')

            deconv_2 = self.deconv2d(inputdata=relu_14, out_channel=64, kernel_size=4,
                                     stride=2, padding='SAME', use_bias=False, name='deconv_2')
            avg_pool_2 = self.avgpooling(inputdata=deconv_2, kernel_size=2, stride=1, padding='SAME',
                                         name='avg_pool_2')
            relu_15 = self.lrelu(inputdata=avg_pool_2, name='relu_15')

            conv_10 = self.conv2d(inputdata=tf.add(relu_15, relu_1), out_channel=32, kernel_size=3,
                                  padding='SAME', stride=1, use_bias=False,
                                  name='conv_10')
            relu_16 = self.lrelu(inputdata=conv_10, name='relu_16')

            skip_output_1 = self.conv2d(inputdata=relu_12, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_ouput_1')

            skip_output_2 = self.conv2d(inputdata=relu_14, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_output_2')

            skip_output_3 = self.conv2d(inputdata=relu_16, out_channel=3, kernel_size=3,
                                        padding='SAME', stride=1, use_bias=False,
                                        name='skip_output_3')

            # as in conventional GANs, the output layer is activated with tanh
            skip_output_3 = tf.nn.tanh(skip_output_3, name='skip_output_3_tanh')

            ret = {
                'skip_1': skip_output_1,
                'skip_2': skip_output_2,
                'skip_3': skip_output_3
            }

        return ret

    def compute_autoencoder_loss(self, input_tensor, label_tensor, name):

        [_, ori_height, ori_width, _] = label_tensor.get_shape().as_list()
        label_tensor_ori = label_tensor
        label_tensor_resize_2 = tf.image.resize_bilinear(images=label_tensor,
                                                         size=(int(ori_height / 2), int(ori_width / 2)))
        label_tensor_resize_4 = tf.image.resize_bilinear(images=label_tensor,
                                                         size=(int(ori_height / 4), int(ori_width / 4)))
        label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor_ori]
        lambda_i = [0.6, 0.8, 1.0]
        # compute lm_loss (equation (5) in the paper)
        lm_loss = tf.constant(0.0, tf.float32)
        with tf.variable_scope(name):
            inference_ret = self.build_autoencoder(input_tensor=input_tensor, name='autoencoder_inference')
            output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
            for index, output in enumerate(output_list):
                mse_loss = tf.losses.mean_squared_error(output, label_list[index]) * lambda_i[index]
                lm_loss = tf.add(lm_loss, mse_loss)

            # compute lp_loss (equation (6) in the paper)
            src_vgg_feats = self._vgg_extractor.extract_feats(input_tensor=label_tensor,
                                                              name='vgg_feats',
                                                              reuse=False)
            pred_vgg_feats = self._vgg_extractor.extract_feats(input_tensor=output_list[-1],
                                                               name='vgg_feats',
                                                               reuse=True)

            lp_losses = []
            for index, feats in enumerate(src_vgg_feats):
                lp_losses.append(tf.losses.mean_squared_error(src_vgg_feats[index], pred_vgg_feats[index]))
            lp_loss = tf.reduce_mean(lp_losses)

            loss = tf.add(lm_loss, lp_loss)

        return loss, inference_ret['skip_3']
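
In equation form, compute_autoencoder_loss implements the sum of the multi-scale loss and the perceptual loss referenced above as equations (5) and (6):

L_M = Σ_i λ_i · MSE(S_i, T_i),    L_P = MSE(VGG(O), VGG(T))

where the S_i are the three skip outputs at 1/4, 1/2 and full resolution, T_i is the ground truth resized to the matching scale, λ_i = {0.6, 0.8, 1.0}, O is the final derained output and T the clean image. Note that the VGG16 feature extractor is built with phase='test', and train_model.py excludes its variables from the generator's training variables, so it stays fixed.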

Main code of cnn_basenet.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://github.com/TJCVRS
#             @File    : cnn_basenet.py

import tensorflow as tf
import numpy as np


class CNNBaseModel(object):
    """
    Base model for other specific cnn ctpn_models
    """

    def __init__(self):
        pass

    @staticmethod
    def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
               stride=1, w_init=None, b_init=None,
               split=1, use_bias=True, data_format='NHWC', name=None):
        """
        Packing the tensorflow conv2d function.
        :param name: op name
        :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other
        unknown dimensions.
        :param out_channel: number of output channel.
        :param kernel_size: int so only support square kernel convolution
        :param padding: 'VALID' or 'SAME'
        :param stride: int so only support square stride
        :param w_init: initializer for convolution weights
        :param b_init: initializer for bias
        :param split: split channels as used in Alexnet mainly group for GPU memory save.
        :param use_bias:  whether to use bias.
        :param data_format: default set to NHWC according tensorflow
        :return: tf.Tensor named ``output``
        """
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'NHWC' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
            assert in_channel % split == 0
            assert out_channel % split == 0

            padding = padding.upper()

            if isinstance(kernel_size, list):
                filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel]
            else:
                filter_shape = [kernel_size, kernel_size] + [in_channel // split, out_channel]

            if isinstance(stride, list):
                strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                    else [1, 1, stride[0], stride[1]]
            else:
                strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                    else [1, 1, stride, stride]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_channel], initializer=b_init)

            if split == 1:
                conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format)
            else:
                inputs = tf.split(inputdata, split, channel_axis)
                kernels = tf.split(w, split, 3)
                outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format)
                           for i, k in zip(inputs, kernels)]
                conv = tf.concat(outputs, channel_axis)

            ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format)
                              if use_bias else conv, name=name)

        return ret

    @staticmethod
    def relu(inputdata, name=None):

        return tf.nn.relu(features=inputdata, name=name)

    @staticmethod
    def sigmoid(inputdata, name=None):
 
        return tf.nn.sigmoid(x=inputdata, name=name)

    @staticmethod
    def maxpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):

        padding = padding.upper()

        if stride is None:
            stride = kernel_size

        if isinstance(kernel_size, list):
            kernel = [1, kernel_size[0], kernel_size[1], 1] if data_format == 'NHWC' else \
                [1, 1, kernel_size[0], kernel_size[1]]
        else:
            kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
                else [1, 1, kernel_size, kernel_size]

        if isinstance(stride, list):
            strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                else [1, 1, stride[0], stride[1]]
        else:
            strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                else [1, 1, stride, stride]

        return tf.nn.max_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def avgpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):

        if stride is None:
            stride = kernel_size

        kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
            else [1, 1, kernel_size, kernel_size]

        strides = [1, stride, stride, 1] if data_format == 'NHWC' else [1, 1, stride, stride]

        return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def globalavgpooling(inputdata, data_format='NHWC', name=None):

        assert inputdata.shape.ndims == 4
        assert data_format in ['NHWC', 'NCHW']

        axis = [1, 2] if data_format == 'NHWC' else [2, 3]

        return tf.reduce_mean(input_tensor=inputdata, axis=axis, name=name)

    @staticmethod
    def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True,
                  data_format='NHWC', name=None):
        """
        :param name:
        :param inputdata:
        :param epsilon: epsilon to avoid divide-by-zero.
        :param use_bias: whether to use the extra affine transformation or not.
        :param use_scale: whether to use the extra affine transformation or not.
        :param data_format:
        :return:
        """
        shape = inputdata.get_shape().as_list()
        ndims = len(shape)
        assert ndims in [2, 4]

        mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True)

        if data_format == 'NCHW':
            channnel = shape[1]
            new_shape = [1, channnel, 1, 1]
        else:
            channnel = shape[-1]
            new_shape = [1, 1, 1, channnel]
        if ndims == 2:
            new_shape = [1, channnel]

        if use_bias:
            beta = tf.get_variable('beta', [channnel], initializer=tf.constant_initializer())
            beta = tf.reshape(beta, new_shape)
        else:
            beta = tf.zeros([1] * ndims, name='beta')
        if use_scale:
            gamma = tf.get_variable('gamma', [channnel], initializer=tf.constant_initializer(1.0))
            gamma = tf.reshape(gamma, new_shape)
        else:
            gamma = tf.ones([1] * ndims, name='gamma')

        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None):
        shape = inputdata.get_shape().as_list()
        if len(shape) != 4:
            raise ValueError("Input data of instancebn layer has to be 4D tensor")

        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
        if ch is None:
            raise ValueError("Input of instancebn require known channel!")

        mean, var = tf.nn.moments(inputdata, axis, keep_dims=True)

        if not use_affine:
            return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output')

        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def dropout(inputdata, keep_prob, noise_shape=None, name=None):
        return tf.nn.dropout(inputdata, keep_prob=keep_prob, noise_shape=noise_shape, name=name)

    @staticmethod
    def fullyconnect(inputdata, out_dim, w_init=None, b_init=None,
                     use_bias=True, name=None):
        """
        Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor.
        It is an equivalent of `tf.layers.dense` except for naming conventions.

        :param inputdata:  a tensor to be flattened except for the first dimension.
        :param out_dim: output dimension
        :param w_init: initializer for w. Defaults to `variance_scaling_initializer`.
        :param b_init: initializer for b. Defaults to zero
        :param use_bias: whether to use bias.
        :param name:
        :return: tf.Tensor: a NC tensor named ``output`` with attribute `variables`.
        """
        shape = inputdata.get_shape().as_list()[1:]
        if None not in shape:
            inputdata = tf.reshape(inputdata, [-1, int(np.prod(shape))])
        else:
            inputdata = tf.reshape(inputdata, tf.stack([tf.shape(inputdata)[0], -1]))

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        ret = tf.layers.dense(inputs=inputdata, activation=lambda x: tf.identity(x, name='output'),
                              use_bias=use_bias, name=name,
                              kernel_initializer=w_init, bias_initializer=b_init,
                              trainable=True, units=out_dim)
        return ret

    @staticmethod
    def layerbn(inputdata, is_training, name):
        with tf.variable_scope(name):
            return tf.layers.batch_normalization(inputs=inputdata, training=is_training)

    @staticmethod
    def squeeze(inputdata, axis=None, name=None):
        return tf.squeeze(input=inputdata, axis=axis, name=name)

    @staticmethod
    def deconv2d(inputdata, out_channel, kernel_size, padding='SAME',
                 stride=1, w_init=None, b_init=None,
                 use_bias=True, activation=None, data_format='channels_last',
                 trainable=True, name=None):
        """
        Packing the tensorflow conv2d function.
        :param name: op name
        :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other
        unknown dimensions.
        :param out_channel: number of output channel.
        :param kernel_size: int so only support square kernel convolution
        :param padding: 'VALID' or 'SAME'
        :param stride: int so only support square stride
        :param w_init: initializer for convolution weights
        :param b_init: initializer for bias
        :param activation: whether to apply a activation func to deconv result
        :param use_bias:  whether to use bias.
        :param data_format: default set to NHWC according tensorflow
        :return: tf.Tensor named ``output``
        """
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'channels_last' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            ret = tf.layers.conv2d_transpose(inputs=inputdata, filters=out_channel,
                                             kernel_size=kernel_size,
                                             strides=stride, padding=padding,
                                             data_format=data_format,
                                             activation=activation, use_bias=use_bias,
                                             kernel_initializer=w_init,
                                             bias_initializer=b_init, trainable=trainable,
                                             name=name)
        return ret

    @staticmethod
    def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME',
                      w_init=None, b_init=None, use_bias=False, name=None):

        with tf.variable_scope(name):
            in_shape = input_tensor.get_shape().as_list()
            in_channel = in_shape[3]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if isinstance(k_size, list):
                filter_shape = [k_size[0], k_size[1]] + [in_channel, out_dims]
            else:
                filter_shape = [k_size, k_size] + [in_channel, out_dims]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_dims], initializer=b_init)

            conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate,
                                       padding=padding, name='dilation_conv')

            if use_bias:
                ret = tf.add(conv, b)
            else:
                ret = conv

        return ret

    @staticmethod
    def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234):
        tf.set_random_seed(seed=seed)

        def f1():
            with tf.variable_scope(name):
                return input_tensor

        def f2():
            with tf.variable_scope(name):
                num_feature_maps = [tf.shape(input_tensor)[0], tf.shape(input_tensor)[3]]

                random_tensor = keep_prob
                random_tensor += tf.random_uniform(num_feature_maps,
                                                   seed=seed,
                                                   dtype=input_tensor.dtype)

                binary_tensor = tf.floor(random_tensor)

                binary_tensor = tf.reshape(binary_tensor,
                                           [-1, 1, 1, tf.shape(input_tensor)[3]])
                ret = input_tensor * binary_tensor
                return ret

        output = tf.cond(is_training, f2, f1)
        return output

    @staticmethod
    def lrelu(inputdata, name, alpha=0.2):
        with tf.variable_scope(name):
            return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata)
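
All of the network classes above inherit from CNNBaseModel, so its static helpers compose like plain TensorFlow ops. A minimal standalone sketch (hypothetical usage, TensorFlow 1.x):

import tensorflow as tf
from attentive_gan_model import cnn_basenet

net = cnn_basenet.CNNBaseModel()
x = tf.placeholder(tf.float32, shape=[1, 240, 360, 3], name='x')

with tf.variable_scope('demo'):
    conv = net.conv2d(inputdata=x, out_channel=32, kernel_size=3,
                      padding='SAME', stride=1, use_bias=False, name='conv')
    feat = net.lrelu(inputdata=conv, name='relu')

print(feat.get_shape().as_list())   # [1, 240, 360, 32]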

Main code of derain_drop_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : derain_drop_net.py

import tensorflow as tf

from attentive_gan_model import attentive_gan_net
from attentive_gan_model import discriminative_net


class DeRainNet(object):

    def __init__(self, phase):

        self._phase = phase
        self._attentive_gan = attentive_gan_net.GenerativeNet(self._phase)
        self._discriminator = discriminative_net.DiscriminativeNet(self._phase)

    def compute_loss(self, input_tensor, gt_label_tensor, mask_label_tensor, name):

        with tf.variable_scope(name):

            # compute the attentive-rnn loss
            attentive_rnn_loss, attentive_rnn_output = self._attentive_gan.compute_attentive_rnn_loss(
                input_tensor=input_tensor,
                label_tensor=mask_label_tensor,
                name='attentive_rnn_loss')

            auto_encoder_input = tf.concat((attentive_rnn_output, input_tensor), axis=-1)

            auto_encoder_loss, auto_encoder_output = self._attentive_gan.compute_autoencoder_loss(
                input_tensor=auto_encoder_input,
                label_tensor=gt_label_tensor,
                name='attentive_autoencoder_loss'
            )

            gan_loss = tf.add(attentive_rnn_loss, auto_encoder_loss)

            discriminative_inference, discriminative_loss = self._discriminator.compute_loss(
                input_tensor=auto_encoder_output,
                label_tensor=gt_label_tensor,
                attention_map=attentive_rnn_output,
                name='discriminative_loss')

            l_gan = tf.reduce_mean(tf.log(tf.subtract(tf.constant(1.0), discriminative_inference)) * 0.01)

            gan_loss = tf.add(gan_loss, l_gan)

            return gan_loss, discriminative_loss, auto_encoder_output

    # used at test time
    def build(self, input_tensor, name):

        with tf.variable_scope(name):

            attentive_rnn_out = self._attentive_gan.build_attentive_rnn(
                input_tensor=input_tensor,
                name='attentive_rnn_loss/attentive_inference'
            )

            attentive_autoencoder_input = tf.concat((attentive_rnn_out['final_attention_map'],
                                                     input_tensor), axis=-1)

            output = self._attentive_gan.build_autoencoder(
                input_tensor=attentive_autoencoder_input,
                name='attentive_autoencoder_loss/autoencoder_inference'
            )

            return output['skip_3'], attentive_rnn_out['attention_map_list']
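
This build method is what test_model.py calls at inference time. A simplified sketch of the test-time path (weights_path and image_batch below are stand-ins you would replace with your own checkpoint and preprocessed image):

import numpy as np
import tensorflow as tf
from attentive_gan_model import derain_drop_net

weights_path = './model/derain_gan_tensorflow/derain_gan.ckpt'   # hypothetical checkpoint path
image_batch = np.zeros((1, 240, 360, 3), np.float32)             # stand-in for a preprocessed image

input_tensor = tf.placeholder(tf.float32, shape=[1, 240, 360, 3], name='input_tensor')
phase = tf.constant('test', tf.string)

net = derain_drop_net.DeRainNet(phase=phase)
# the scope name must match the one used in train_model.py ('derain_net_loss'),
# otherwise the checkpoint variables cannot be restored
output, attention_maps = net.build(input_tensor=input_tensor, name='derain_net_loss')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess=sess, save_path=weights_path)
    derained = sess.run(output, feed_dict={input_tensor: image_batch})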

Main code of discriminative_net.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : discriminative_net.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet


class DiscriminativeNet(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(DiscriminativeNet, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, stride, out_dims, name):

        with tf.variable_scope(name):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims, kernel_size=k_size,
                               padding='SAME', stride=stride, use_bias=False, name='conv')

            relu = self.lrelu(conv, name='relu')

        return relu

    def build(self, input_tensor, name, reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            conv_stage_1 = self._conv_stage(input_tensor=input_tensor, k_size=5,
                                            stride=1, out_dims=8,
                                            name='conv_stage_1')
            conv_stage_2 = self._conv_stage(input_tensor=conv_stage_1, k_size=5,
                                            stride=1, out_dims=16, name='conv_stage_2')
            conv_stage_3 = self._conv_stage(input_tensor=conv_stage_2, k_size=5,
                                            stride=1, out_dims=32, name='conv_stage_3')
            conv_stage_4 = self._conv_stage(input_tensor=conv_stage_3, k_size=5,
                                            stride=1, out_dims=64, name='conv_stage_4')
            conv_stage_5 = self._conv_stage(input_tensor=conv_stage_4, k_size=5,
                                            stride=1, out_dims=128, name='conv_stage_5')
            conv_stage_6 = self._conv_stage(input_tensor=conv_stage_5, k_size=5,
                                            stride=1, out_dims=128, name='conv_stage_6')
            attention_map = self.conv2d(inputdata=conv_stage_6, out_channel=1, kernel_size=5,
                                        padding='SAME', stride=1, use_bias=False, name='attention_map')
            conv_stage_7 = self._conv_stage(input_tensor=attention_map * conv_stage_6, k_size=5,
                                            stride=4, out_dims=64, name='conv_stage_7')
            conv_stage_8 = self._conv_stage(input_tensor=conv_stage_7, k_size=5,
                                            stride=4, out_dims=64, name='conv_stage_8')
            conv_stage_9 = self._conv_stage(input_tensor=conv_stage_8, k_size=5,
                                            stride=4, out_dims=32, name='conv_stage_9')
            fc_1 = self.fullyconnect(inputdata=conv_stage_9, out_dim=1024, use_bias=False, name='fc_1')
            fc_2 = self.fullyconnect(inputdata=fc_1, out_dim=1, use_bias=False, name='fc_2')
            fc_out = self.sigmoid(inputdata=fc_2, name='fc_out')

            fc_out = tf.where(tf.not_equal(fc_out, 1.0), fc_out, fc_out - 0.0000001)
            fc_out = tf.where(tf.not_equal(fc_out, 0.0), fc_out, fc_out + 0.0000001)

            return fc_out, attention_map, fc_2

    def compute_loss(self, input_tensor, label_tensor, attention_map, name):

        with tf.variable_scope(name):
            [batch_size, image_h, image_w, _] = input_tensor.get_shape().as_list()

            # the all-zero map from the paper: a real clean image should
            # produce an empty attention map
            zeros_mask = tf.zeros(shape=[batch_size, image_h, image_w, 1],
                                  dtype=tf.float32, name='O')
            fc_out_o, attention_mask_o, fc2_o = self.build(
                input_tensor=input_tensor, name='discriminative_inference')
            fc_out_r, attention_mask_r, fc2_r = self.build(
                input_tensor=label_tensor, name='discriminative_inference', reuse=True)

            l_map = tf.losses.mean_squared_error(attention_map, attention_mask_o) + \
                    tf.losses.mean_squared_error(attention_mask_r, zeros_mask)

            entropy_loss = -tf.log(fc_out_r) - tf.log(-tf.subtract(fc_out_o, tf.constant(1.0, tf.float32)))
            entropy_loss = tf.reduce_mean(entropy_loss)

            loss = entropy_loss + 0.05 * l_map

            return fc_out_o, loss
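
In equation form, the discriminator loss computed above is

L_D = -log D(R) - log(1 - D(O)) + 0.05 · (MSE(D_map(O), A) + MSE(D_map(R), 0))

where O is the derained output, R the real clean image, A the generator's final attention map, D_map the attention map extracted inside the discriminator, and 0 the all-zero map.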

Main code of tf_ssim.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : tf_ssim.py

import tensorflow as tf
import numpy as np


class SsimComputer(object):

    def __init__(self):
        pass

    @staticmethod
    def _tf_fspecial_gauss(size, sigma):

        x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]

        x_data = np.expand_dims(x_data, axis=-1)
        x_data = np.expand_dims(x_data, axis=-1)

        y_data = np.expand_dims(y_data, axis=-1)
        y_data = np.expand_dims(y_data, axis=-1)

        x = tf.constant(x_data, dtype=tf.float32)
        y = tf.constant(y_data, dtype=tf.float32)

        g = tf.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
        return g / tf.reduce_sum(g)

    def compute_ssim(self, img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):

        assert img1.get_shape().as_list()[-1] == 1, 'Image must be gray scale'
        assert img2.get_shape().as_list()[-1] == 1, 'Image must be gray scale'

        window = self._tf_fspecial_gauss(size, sigma)  # window shape [size, size]
        K1 = 0.01  # origin parameter in paper
        K2 = 0.03  # origin parameter in paper
        L = 1  # depth of image (255 in case the image has a different scale)
        C1 = (K1 * L) ** 2
        C2 = (K2 * L) ** 2
        mu1 = tf.nn.conv2d(img1, window, strides=[1, 1, 1, 1], padding='VALID')
        mu2 = tf.nn.conv2d(img2, window, strides=[1, 1, 1, 1], padding='VALID')
        mu1_sq = mu1 * mu1
        mu2_sq = mu2 * mu2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = tf.nn.conv2d(img1 * img1, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_sq
        sigma2_sq = tf.nn.conv2d(img2 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu2_sq
        sigma12 = tf.nn.conv2d(img1 * img2, window, strides=[1, 1, 1, 1], padding='VALID') - mu1_mu2
        if cs_map:
            value = (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                  (sigma1_sq + sigma2_sq + C2)),
                     (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
        else:
            value = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                 (sigma1_sq + sigma2_sq + C2))

        if mean_metric:
            value = tf.reduce_mean(value)
        return value
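
A minimal usage sketch (compute_ssim expects single-channel inputs, so color images must be converted to grayscale first, exactly as train_model.py does; L is set to 1, so inputs should be scaled accordingly):

import tensorflow as tf
from attentive_gan_model import tf_ssim

ssim_computer = tf_ssim.SsimComputer()
img_a = tf.placeholder(tf.float32, shape=[1, 240, 360, 1])
img_b = tf.placeholder(tf.float32, shape=[1, 240, 360, 1])
ssim = ssim_computer.compute_ssim(img_a, img_b)   # scalar SSIM tensor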

Main code of vgg16.py:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : vgg16.py

import tensorflow as tf

from attentive_gan_model import cnn_basenet


class VGG16Encoder(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):

        super(VGG16Encoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()
        print('VGG16 Network init complete')

    def _init_phase(self):

        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, out_dims, name,
                    stride=1, pad='SAME', reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims,
                               kernel_size=k_size, stride=stride,
                               use_bias=False, padding=pad, name='conv')
            relu = self.relu(inputdata=conv, name='relu')

            return relu

    def _fc_stage(self, input_tensor, out_dims, name, use_bias=False, reuse=False):

        with tf.variable_scope(name, reuse=reuse):
            fc = self.fullyconnect(inputdata=input_tensor, out_dim=out_dims, use_bias=use_bias,
                                   name='fc')
            relu = self.relu(inputdata=fc, name='relu')

        return relu

    def extract_feats(self, input_tensor, name, reuse=False):
        with tf.variable_scope(name, reuse=reuse):
            # conv stage 1_1
            conv_1_1 = self._conv_stage(input_tensor=input_tensor, k_size=3,
                                        out_dims=64, name='conv1_1')

            # conv stage 1_2
            conv_1_2 = self._conv_stage(input_tensor=conv_1_1, k_size=3,
                                        out_dims=64, name='conv1_2')

            # pool stage 1
            pool1 = self.maxpooling(inputdata=conv_1_2, kernel_size=2,
                                    stride=2, name='pool1')

            # conv stage 2_1
            conv_2_1 = self._conv_stage(input_tensor=pool1, k_size=3,
                                        out_dims=128, name='conv2_1')

            # conv stage 2_2
            conv_2_2 = self._conv_stage(input_tensor=conv_2_1, k_size=3,
                                        out_dims=128, name='conv2_2')

            # pool stage 2
            pool2 = self.maxpooling(inputdata=conv_2_2, kernel_size=2,
                                    stride=2, name='pool2')

            # conv stage 3_1
            conv_3_1 = self._conv_stage(input_tensor=pool2, k_size=3,
                                        out_dims=256, name='conv3_1')

            # conv_stage 3_2
            conv_3_2 = self._conv_stage(input_tensor=conv_3_1, k_size=3,
                                        out_dims=256, name='conv3_2')

            # conv stage 3_3
            conv_3_3 = self._conv_stage(input_tensor=conv_3_2, k_size=3,
                                        out_dims=256, name='conv3_3')

            ret = (conv_1_1, conv_1_2, conv_2_1, conv_2_2,
                   conv_3_1, conv_3_2, conv_3_3)

        return ret

5. Files in the config folder

The config folder contains only global_config.py, which stores the network's parameters. Note: if your GPU has little memory, it is advisable to reduce the parameters. I used a GTX 1060 with 3 GB of memory and ran out of memory when running the code as-is, so consider reducing the image size (an example follows the listing). The parameter settings are given below (I changed only a little):

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : global_config.py
from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by: from config import cfg

cfg = __C

# Train options
__C.TRAIN = edict()

__C.TRAIN.EPOCHS = 20010
__C.TRAIN.LEARNING_RATE = 0.002
# Set the GPU resource used during training process
__C.TRAIN.GPU_MEMORY_FRACTION = 0.95
# Set the GPU allow growth parameter during tensorflow training process
__C.TRAIN.TF_ALLOW_GROWTH = True
__C.TRAIN.BATCH_SIZE = 1
__C.TRAIN.IMG_HEIGHT = 240
__C.TRAIN.IMG_WIDTH = 360

# Test options
__C.TEST = edict()

# Set the GPU resource used during testing process
__C.TEST.GPU_MEMORY_FRACTION = 0.8
# Set the GPU allow growth parameter during tensorflow testing process
__C.TEST.TF_ALLOW_GROWTH = True
__C.TEST.BATCH_SIZE = 1
__C.TEST.IMG_HEIGHT = 240
__C.TEST.IMG_WIDTH = 360
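
For example, on a GPU with little memory you could halve the input resolution (the values below are just a suggestion, not from the original repo; keep height and width divisible by 4, because the autoencoder downsamples twice with stride 2):

__C.TRAIN.IMG_HEIGHT = 120
__C.TRAIN.IMG_WIDTH = 180
__C.TEST.IMG_HEIGHT = 120
__C.TEST.IMG_WIDTH = 180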

6. The data2txt.py file

This file is one I wrote myself for generating train.txt; the code:

import os


def data2txt(data_rootdir):
    # read all images in the two folders and check that the counts match
    images = os.listdir(data_rootdir + 'data/')
    labels = os.listdir(data_rootdir + 'gt/')
    images.sort()
    labels.sort()

    image_len = len(images)
    label_len = len(labels)

    assert image_len == label_len

    # open the text file and write the path pairs
    trainText = open(data_rootdir + 'train.txt', 'w')
    for i in range(image_len):
        image_dir = data_rootdir + 'data/' + images[i] + ' '
        label_dir = data_rootdir + 'gt/' + labels[i] + '\n'

        trainText.write(image_dir)
        trainText.write(label_dir)

    trainText.close()
    print('finished!')


if __name__ == '__main__':
    data2txt('./data/training_data/')

7. The train_model.py file

train_model.py is the training script. Before training, the main thing is to check the parameter settings; once they are set, run the file directly to start training. Its code follows:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : train_model.py

import os
import os.path as ops
import argparse
import time

import tensorflow as tf
import numpy as np
import glog as log

from data_provider import data_provider
from config import global_config
from attentive_gan_model import derain_drop_net
from attentive_gan_model import tf_ssim

CFG = global_config.cfg
VGG_MEAN = [103.939, 116.779, 123.68]


def init_args():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, default='./data/training_data/', help='The dataset dir')
    parser.add_argument('--weights_path', type=str,
                        # default='',
                        help='The pretrained weights path', default=None)

    return parser.parse_args()


def train_model(dataset_dir, weights_path=None):

    # build the dataset
    with tf.device('/gpu:0'):
        train_dataset = data_provider.DataSet(ops.join(dataset_dir, 'train.txt'))

        # declare the placeholders
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        label_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
                                      name='label_tensor')
        mask_tensor = tf.placeholder(dtype=tf.float32,
                                     shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 1],
                                     name='mask_tensor')
        lr_tensor = tf.placeholder(dtype=tf.float32,
                                   shape=[],
                                   name='learning_rate')
        phase_tensor = tf.placeholder(dtype=tf.string, shape=[], name='phase')

        # instantiate the SSIM computer
        ssim_computer = tf_ssim.SsimComputer()

        # build the network
        derain_net = derain_drop_net.DeRainNet(phase=phase_tensor)

        gan_loss, discriminative_loss, net_output = derain_net.compute_loss(
            input_tensor=input_tensor,
            gt_label_tensor=label_tensor,
            mask_label_tensor=mask_tensor,
            name='derain_net_loss')

        train_vars = tf.trainable_variables()

        ssim = ssim_computer.compute_ssim(tf.image.rgb_to_grayscale(net_output),
                                          tf.image.rgb_to_grayscale(label_tensor))

        d_vars = [tmp for tmp in train_vars if 'discriminative_loss' in tmp.name]
        g_vars = [tmp for tmp in train_vars if 'attentive_' in tmp.name and 'vgg_feats' not in tmp.name]
        vgg_vars = [tmp for tmp in train_vars if "vgg_feats" in tmp.name]

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(lr_tensor, global_step, 100000, 0.1, staircase=True)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            d_optim = tf.train.AdamOptimizer(learning_rate).minimize(discriminative_loss, var_list=d_vars)
            g_optim = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=tf.constant(0.9, tf.float32)).minimize(gan_loss, var_list=g_vars)

        # Set tf saver
        saver = tf.train.Saver()
        model_save_dir = './model/derain_gan_tensorflow'
        if not ops.exists(model_save_dir):
            os.makedirs(model_save_dir)
        train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        model_name = 'derain_gan_{:s}.ckpt'.format(str(train_start_time))
        model_save_path = ops.join(model_save_dir, model_name)

        # Set tf summary
        tboard_save_path = './tboard/derain_gan_tensorflow'
        if not ops.exists(tboard_save_path):
            os.makedirs(tboard_save_path)
        g_loss_scalar = tf.summary.scalar(name='gan_loss', tensor=gan_loss)
        d_loss_scalar = tf.summary.scalar(name='discriminative_loss', tensor=discriminative_loss)
        ssim_scalar = tf.summary.scalar(name='image_ssim', tensor=ssim)
        lr_scalar = tf.summary.scalar(name='learning_rate', tensor=lr_tensor)
        d_summary_op = tf.summary.merge([d_loss_scalar, lr_scalar])
        g_summary_op = tf.summary.merge([g_loss_scalar, ssim_scalar])

        # Set sess configuration
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'

        sess = tf.Session(config=sess_config)

        summary_writer = tf.summary.FileWriter(tboard_save_path)
        summary_writer.add_graph(sess.graph)

        # Set the training parameters
        train_epochs = CFG.TRAIN.EPOCHS

        log.info('Global configuration is as follows:')
        log.info(CFG)

        with sess.as_default():

            tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                                 name='{:s}/derain_gan.pb'.format(model_save_dir))

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            # Load the pretrained VGG16 weights
            # (allow_pickle=True is required on NumPy >= 1.16.3)
            pretrained_weights = np.load('./data/vgg16.npy', encoding='latin1', allow_pickle=True).item()

            for vv in vgg_vars:
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    continue

            # train loop
            for epoch in range(train_epochs):
                # training part
                t_start = time.time()

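                # next_batch returns rainy inputs, their clean ground-truth
                # counterparts and the corresponding raindrop masks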
                gt_imgs, label_imgs, mask_imgs = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)

                mask_imgs = [np.expand_dims(tmp, axis=-1) for tmp in mask_imgs]

                # Update discriminative Network
                _, d_loss, d_summary = sess.run(
                    [d_optim, discriminative_loss, d_summary_op],
                    feed_dict={input_tensor: gt_imgs,
                               label_tensor: label_imgs,
                               mask_tensor: mask_imgs,
                               lr_tensor: CFG.TRAIN.LEARNING_RATE,
                               phase_tensor: 'train'})

                # Update attentive gan Network
                _, g_loss, g_summary, ssim_val = sess.run(
                    [g_optim, gan_loss, g_summary_op, ssim],
                    feed_dict={input_tensor: gt_imgs,
                               label_tensor: label_imgs,
                               mask_tensor: mask_imgs,
                               lr_tensor: CFG.TRAIN.LEARNING_RATE,
                               phase_tensor: 'train'})

                summary_writer.add_summary(d_summary, global_step=epoch)
                summary_writer.add_summary(g_summary, global_step=epoch)

                cost_time = time.time() - t_start

                log.info('Epoch: {:d} D_loss: {:.5f} G_loss: '
                         '{:.5f} Ssim: {:.5f} Cost_time: {:.5f}s'.format(epoch, d_loss, g_loss,
                                                                         ssim_val, cost_time))
                if epoch % 500 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
        sess.close()

    return


if __name__ == '__main__':
    args = init_args()
    train_model(args.dataset_dir, weights_path=args.weights_path)

8. The test_model.py file

After the model has been trained, open test_model.py, set the corresponding parameters (mainly the weights path), and run it directly:

# cited from:
#             @Author  : Luo Yao
#             @Site    : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
#             @File    : test_model.py

import os.path as ops
import argparse

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr
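# NOTE: on scikit-image >= 0.16, compare_ssim/compare_psnr were moved to
# skimage.metrics as structural_similarity and peak_signal_noise_ratio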

from attentive_gan_model import derain_drop_net
from config import global_config

CFG = global_config.cfg


def init_args():

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str,
                        default='./data/test_data/27_rain.jpg',
                        help='The input image path')
    parser.add_argument('--weights_path', type=str,
                        default='./model/new_model/derain_gan_2018-11-02-19-55-27.ckpt-200000',
                        help='The model weights path')

    return parser.parse_args()


def minmax_scale(input_arr):
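    # Linearly stretch all values of the array to the range [0, 255]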

    min_val = np.min(input_arr)
    max_val = np.max(input_arr)

    output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

    return output_arr


def test_model(image_path, weights_path):

    assert ops.exists(image_path)

    with tf.device('/gpu:0'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TEST.BATCH_SIZE, CFG.TEST.IMG_HEIGHT, CFG.TEST.IMG_WIDTH, 3],
                                      name='input_tensor'
                                      )

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (CFG.TEST.IMG_WIDTH, CFG.TEST.IMG_HEIGHT))
    image_vis = image
    image = np.divide(image, 127.5) - 1

    phase = tf.constant('test', tf.string)

    with tf.device('/gpu:0'):
        net = derain_drop_net.DeRainNet(phase=phase)
        output, attention_maps = net.build(input_tensor=input_tensor, name='derain_net_loss')

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    saver = tf.train.Saver()

    with tf.device('/gpu:0'):
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            output_image, atte_maps = sess.run(
                [output, attention_maps],
                feed_dict={input_tensor: np.expand_dims(image, 0)})

            output_image = output_image[0]
            for i in range(output_image.shape[2]):
                output_image[:, :, i] = minmax_scale(output_image[:, :, i])

            output_image = np.array(output_image, np.uint8)

            # Compute image metrics; note that they are measured against the
            # rainy input image, since no ground truth is available at test time
            image_ssim = ssim(
                image_vis,
                output_image,
                data_range=output_image.max() - output_image.min(),
                multichannel=True)
            image_psnr = psnr(
                image_vis,
                output_image,
                data_range=output_image.max() - output_image.min())

            print('Image ssim: {:.6f}'.format(image_ssim))
            print('Image psnr: {:.6f}'.format(image_psnr))

            # Save and visualize the results
            cv2.imwrite('src_img.png', image_vis)
            cv2.imwrite('derain_ret.png', output_image)

            plt.figure('src_image')
            plt.imshow(image_vis[:, :, (2, 1, 0)])
            plt.figure('derain_ret')
            plt.imshow(output_image[:, :, (2, 1, 0)])
            # Plot and save each attention map produced by the recurrent network
            for map_id, atte_map in enumerate(atte_maps):
                plt.figure('atte_map_{:d}'.format(map_id + 1))
                plt.imshow(atte_map[0, :, :, 0], cmap='jet')
                plt.savefig('atte_map_{:d}.png'.format(map_id + 1))
            plt.show()


if __name__ == '__main__':
    args = init_args()
    test_model(args.image_path, args.weights_path)

IV. The attentive GAN implementation process

1. Install the required development environment

Two of the required libraries are not very common; if you have never used them, you will need to install them manually:

glog==0.3.1

easydict==1.6

These two libraries cannot be installed through PyCharm's quick-install feature, so open Git Bash (or any other terminal) and install them with pip; a confirmation message is printed once installation succeeds. Note also that the code is written against the TensorFlow 1.x API (tf.placeholder, tf.Session), so a 1.x installation is assumed.
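For example, in the terminal:

    pip install glog==0.3.1 easydict==1.6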

2. 训练

The entire training process only involves train_model.py and data2txt.py. Put the data in the corresponding locations, then run data2txt.py directly; this generates the train.txt file under './data/training_data/', which lists all of the paired training data we are going to use (an illustrative sketch of its format is given below).
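Each line of train.txt is expected to pair one rainy image with its clean counterpart. This is only an illustrative sketch assuming the directory layout from Section III; the exact separator and column order depend on how data2txt.py and data_provider.py were written:

    ./data/training_data/data/0_rain.png ./data/training_data/gt/0_clear.png
    ./data/training_data/data/1_rain.png ./data/training_data/gt/1_clear.png
    ......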

(1) Training from scratch

To train from scratch, open train_model.py and set the path of the 'dataset_dir' argument; leave 'weights_path' unset (its default is None). Then run train_model.py and training begins.
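Equivalently, the path can be passed on the command line instead of editing the default:

    python train_model.py --dataset_dir ./data/training_data/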

(2) Resuming training from a previous checkpoint

Of course, if training has run for a while and was interrupted, you can resume from the last checkpoint. In that case, simply set the default value of 'weights_path' to the saved checkpoint prefix (the '.ckpt-<step>' path written by saver.save, not the 'checkpoint' index file in the same directory), for example:

    parser.add_argument('--weights_path', type=str,
                        default='./model/derain_GAN_tensorflow10/derain_gan_<timestamp>.ckpt-<step>',
                        help='The pretrained weights path')

Once this is set, run train_model.py and training continues from the previous result. If everything is working, the per-epoch log lines (Epoch, D_loss, G_loss, Ssim, Cost_time) start printing again.
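The same thing can be done from the command line without editing the file; the checkpoint name below is a placeholder for whatever saver.save actually produced:

    python train_model.py --dataset_dir ./data/training_data/ --weights_path ./model/derain_gan_tensorflow/derain_gan_<timestamp>.ckpt-<step>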

3. Testing

Open test_model.py and fill in the path of the test image and of the trained model; both can be supplied as argparse defaults, for example:

    parser.add_argument('--image_path', type=str,
                        default='./data/test_data/test_2.png',
                        help='The input image path')
    parser.add_argument('--weights_path', type=str,
                        default='./model/new_model/derain_gan_2018-11-02-19-55-27.ckpt-200000',
                        help='The model weights path')

Once they are set, run the script directly; the de-rained result and the four attention maps are written to the project root directory.
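Equivalently, both paths can be passed on the command line; the paths below reuse the example defaults shown above:

    python test_model.py --image_path ./data/test_data/test_2.png --weights_path ./model/new_model/derain_gan_2018-11-02-19-55-27.ckpt-200000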

V. Experimental results

The full training run takes quite a long time; if you do not want to wait, you can use the pretrained model provided by the author of [2], which is available directly on GitHub. Load the image to be tested in the same way, keeping in mind that only one image can be tested at a time. The results below were produced with the author's pretrained model.

When the raindrops are small, the output is restored almost perfectly; when the raindrops are large, the original image cannot be fully recovered, but the raindrops are still removed reasonably well.

VI. Analysis

1. attentive GAN introduces an attention map to help the network locate raindrop regions.
