[Deep Learning] FCN Semantic Segmentation on the Sift-flow Dataset + Batch Normalization for Much Faster Convergence (Training FCN from Scratch, Without VGG Weights)

FCN Semantic Segmentation on the Sift-flow Dataset + Batch Normalization

Preface

In my previous post I built an FCN that needed nearly 30 hours of training before the results were even faintly visible.
As I have admitted, I am lazy, so there was no way I would wait that long. I thought of adding BN; a quick search turned up no one who had tried it, so I just went ahead and did it, and a job that used to take 30 hours finished in about half an hour.
A labmate reminded me that BN had not yet been published when FCN came out. Checking the dates, the BN paper appeared on 2015-03-02 and the FCN paper on 2015-08-08, so the two landmark techniques were only a few months apart.
With the BN layers added, training converges faster and to better values. To help with tuning, I also added tf.train.exponential_decay so the learning rate decays automatically.
The segmentation results are almost identical to those in the previous post, which you can consult for comparison.
The network diagram is as follows:
(Figure: network architecture)

Code

#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '2'  # pin the script to GPU 2
os.system("rm -r logs")                   # clear old TensorBoard logs before a new run
import tensorflow as tf
import matplotlib.pyplot as plt


# In[ ]:


trainPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/train.tfrecords'
testPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/test.tfrecords'
valPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/val.tfrecords'

model_path = '/home/winsoul/disk/Segmentation/SiftFlow/FCN_modelUseBN/model_backup/'
DisplayStep = 25
ModelSaverStep = 2500
decay_step = 500
decay_rate = 0.995
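
With staircase decay, the learning rate drops by a factor of decay_rate once every decay_step batches. A minimal plain-Python sketch of the schedule these two constants produce (mirroring what tf.train.exponential_decay with staircase = True computes later in the script):

def decayed_rate(start_rate, step, decay_step = 500, decay_rate = 0.995):
    # staircase: integer division keeps the rate constant within each 500-batch window
    return start_rate * decay_rate ** (step // decay_step)

# e.g. decayed_rate(1e-5, 2500) = 1e-5 * 0.995 ** 5, about 2.5% below the starting rate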


# In[ ]:


def read_tfrecords(TFRecordsPath):
    # This only builds the graph ops that decode one example; no Session is needed here.
    feature = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
        'name': tf.FixedLenFeature([], tf.string),
    }
    filename_queue = tf.train.string_input_producer([TFRecordsPath])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features = feature)

    # Images were stored as raw float32 bytes, labels as raw uint8 bytes.
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [224, 224, 3])

    label = tf.decode_raw(features['label'], tf.uint8)
    label = tf.reshape(label, [224, 224])

    return image, label
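
This reader only decodes correctly because the records were serialized the same way: images as raw float32 bytes, labels as raw uint8 bytes, each example also carrying a name string. For reference, a minimal sketch of a matching writer; write_example is a hypothetical helper, not part of the original preprocessing code:

def write_example(writer, image, label, name):
    # image: float32 array of shape [224, 224, 3]; label: uint8 array of shape [224, 224]
    feature = {
        'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image.astype(np.float32).tobytes()])),
        'label': tf.train.Feature(bytes_list = tf.train.BytesList(value = [label.astype(np.uint8).tobytes()])),
        'name': tf.train.Feature(bytes_list = tf.train.BytesList(value = [name.encode('utf-8')])),
    }
    example = tf.train.Example(features = tf.train.Features(feature = feature))
    writer.write(example.SerializeToString())

# usage sketch: writer = tf.python_io.TFRecordWriter(trainPath); write_example(writer, img, lbl, 'img_0001')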


# In[ ]:


def conv_layer(X, k, s, channels_in, channels_out, name = 'CONV', padding = 'SAME', is_training = True):
    with tf.name_scope(name):
        W = tf.Variable(tf.truncated_normal([k, k, channels_in, channels_out], stddev = 0.1))
        b = tf.Variable(tf.constant(0.1, shape = [channels_out]))
        conv = tf.nn.conv2d(X, W, strides = [1, s, s, 1], padding = padding)
        # BN right after the convolution; the bias is technically redundant here
        # (BN's beta absorbs it) but harmless.
        bn = tf.layers.batch_normalization(conv + b, training = is_training)
        result = tf.nn.relu(bn)
#         tf.summary.histogram('weights', W)
#         tf.summary.histogram('biases', b)
#         tf.summary.histogram('activations', result)
        return result
    
def pool_layer(X, k, s, padding = 'SAME', pool_type = 'MAX', name = 'pool'):
    with tf.name_scope(name):
        # k x k window with stride s; MAX by default, AVG otherwise.
        if pool_type == 'MAX':
            result = tf.nn.max_pool(X,
                                  ksize = [1, k, k, 1],
                                  strides = [1, s, s, 1],
                                  padding = padding, name = name)
        else:
            result = tf.nn.avg_pool(X,
                                  ksize = [1, k, k, 1],
                                  strides = [1, s, s, 1],
                                  padding = padding, name = name)
        return result
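
# Note: with padding 'SAME' and stride s, pooling shrinks each spatial dimension by a
# factor of s; the five stride-2 pools in the network below go 224 -> 112 -> 56 -> 28 -> 14 -> 7.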

def fc_layer(X, neurons_in, neurons_out, last = False, name = 'FC'):
    with tf.name_scope(name):
        W = tf.Variable(tf.truncated_normal([neurons_in, neurons_out], stddev = 0.1))
        b = tf.Variable(tf.constant(0.1, shape = [neurons_out]))
#         tf.summary.histogram('weights', W)
#         tf.summary.histogram('biases', b)
        # ReLU for hidden layers, softmax on the last layer (unused in this fully convolutional net).
        if not last:
            result = tf.nn.relu(tf.matmul(X, W) + b)
        else:
            result = tf.nn.softmax(tf.matmul(X, W) + b)
#         tf.summary.histogram('activations', result)
        return result
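
One pitfall when adding BN through tf.layers.batch_normalization: the moving-average updates it creates are collected in tf.GraphKeys.UPDATE_OPS and do not run on their own. Training still works without them, but the inference-time statistics stay at their initial values. A standalone toy illustration of the required pattern (not the FCN graph; Network below applies the same pattern around its optimizer):

x = tf.placeholder(tf.float32, [None, 8])
training = tf.placeholder(tf.bool)
h = tf.layers.batch_normalization(tf.layers.dense(x, 4), training = training)
loss = tf.reduce_mean(tf.square(h))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):   # BN's moving mean/variance updates run with every step
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)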


# In[ ]:


def conv_transpose_layer(X, k, s, input_shape, output_shape, name = 'CONV_TRAN', padding = 'SAME'):
    with tf.name_scope(name):
        # Note the filter layout for conv2d_transpose: [k, k, out_channels, in_channels].
        W = tf.Variable(tf.truncated_normal([k, k, output_shape[3].value, input_shape[3].value], stddev = 0.1))
        b = tf.Variable(tf.constant(0.1, shape = [output_shape[3].value]))
        deconv = tf.nn.conv2d_transpose(X, W, output_shape, strides = [1, s, s, 1], padding = padding)
        result = tf.add(deconv, b)
#         tf.summary.histogram('weights', W)
#         tf.summary.histogram('biases', b)
#         tf.summary.histogram('activations', result)
        return result
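
With padding 'SAME', tf.nn.conv2d_transpose upsamples an n x n input to (n * s) x (n * s) regardless of the kernel size k; that is exactly the arithmetic the decoder relies on. A quick sanity check of the three upsampling steps used below:

def transpose_output_size(n, s):
    # 'SAME' transposed convolution: output spatial size = input size * stride
    return n * s

assert transpose_output_size(7, 2) == 14    # conv6_3 (7x7)   -> pool4's 14x14
assert transpose_output_size(14, 2) == 28   # fuse1   (14x14) -> pool3's 28x28
assert transpose_output_size(28, 8) == 224  # fuse2   (28x28) -> the 224x224 input size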


# In[ ]:


def Network(BatchSize, start_learning_rate):
    tf.reset_default_graph()

    with tf.Session() as sess:
        is_training = tf.placeholder(dtype = tf.bool)    # switches the input pipeline (and BN) between train and val
        keep_prob = tf.placeholder(dtype = tf.float32)   # dropout keep probability
        global_step = tf.placeholder(dtype = tf.int32)   # fed manually each batch to drive the decay schedule
        # Only this placeholder's static shape is used: [BatchSize, 224, 224, 4] is the
        # output shape (4 classes) of the final transposed convolution. It is never fed.
        origin_image = tf.placeholder(tf.uint8, shape = [BatchSize, 224, 224, 4])
        y_label = tf.placeholder(tf.int32, shape = [None, 224, 224, 1], name = "y_label")
        
        image_train, label_train = read_tfrecords(trainPath)
        image_val, label_val = read_tfrecords(valPath)
        
        image_train_batch, label_train_batch = tf.train.shuffle_batch([image_train, label_train],
                                                         batch_size = BatchSize,
                                                         capacity = BatchSize * 3 + 200,
                                                         min_after_dequeue = BatchSize)
        image_val_batch, label_val_batch = tf.train.shuffle_batch([image_val, label_val],
                                                         batch_size = BatchSize,
                                                         capacity = BatchSize * 3 + 200,
                                                         min_after_dequeue = BatchSize)
        # Route either the training or the validation batch through the same graph.
        image_Batch = tf.cond(is_training, lambda: image_train_batch, lambda: image_val_batch)
        label_Batch = tf.cond(is_training, lambda: label_train_batch, lambda: label_val_batch)

        X = tf.identity(image_Batch)
        y = tf.cast(label_Batch, tf.int32)
            
        # Encoder: VGG16-style blocks trained from scratch (no pretrained weights).
        # Pass the is_training placeholder through so BN uses moving averages at eval time.
        conv1_1 = conv_layer(X, 3, 1, 3, 64, "conv1_1", is_training = is_training)
        conv1_2 = conv_layer(conv1_1, 3, 1, 64, 64, "conv1_2", is_training = is_training)
        pool1 = pool_layer(conv1_2, 2, 2, "SAME", "MAX", "pool1")

        conv2_1 = conv_layer(pool1, 3, 1, 64, 128, "conv2_1", is_training = is_training)
        conv2_2 = conv_layer(conv2_1, 3, 1, 128, 128, "conv2_2", is_training = is_training)
        pool2 = pool_layer(conv2_2, 2, 2, "SAME", "MAX", 'pool2')

        conv3_1 = conv_layer(pool2, 3, 1, 128, 256, "conv3_1", is_training = is_training)
        conv3_2 = conv_layer(conv3_1, 3, 1, 256, 256, "conv3_2", is_training = is_training)
        conv3_3 = conv_layer(conv3_2, 3, 1, 256, 256, "conv3_3", is_training = is_training)
        pool3 = pool_layer(conv3_3, 2, 2, "SAME", "MAX", 'pool3')

        conv4_1 = conv_layer(pool3, 3, 1, 256, 512, "conv4_1", is_training = is_training)
        conv4_2 = conv_layer(conv4_1, 3, 1, 512, 512, "conv4_2", is_training = is_training)
        conv4_3 = conv_layer(conv4_2, 3, 1, 512, 512, "conv4_3", is_training = is_training)
        pool4 = pool_layer(conv4_3, 2, 2, "SAME", "MAX", 'pool4')
        print(pool4)

        conv5_1 = conv_layer(pool4, 3, 1, 512, 512, "conv5_1", is_training = is_training)
        conv5_2 = conv_layer(conv5_1, 3, 1, 512, 512, "conv5_2", is_training = is_training)
        conv5_3 = conv_layer(conv5_2, 3, 1, 512, 512, "conv5_3", is_training = is_training)
        pool5 = pool_layer(conv5_3, 2, 2, "SAME", "MAX", 'pool5')
        print(pool5)

        # Convolutionalized classifier head: the last 1x1 conv maps down to the 4 classes.
        conv6_1 = conv_layer(pool5, 7, 1, 512, 1024, "conv6_1", is_training = is_training)
        conv6_2 = conv_layer(conv6_1, 1, 1, 1024, 512, "conv6_2", is_training = is_training)
        conv6_3 = conv_layer(conv6_2, 1, 1, 512, 4, "conv6_3", is_training = is_training)
        drop1 = tf.nn.dropout(conv6_3, keep_prob)
        print(drop1)

        # Decoder with FCN-8s-style skip connections: upsample 2x and fuse with pool4,
        # 2x again and fuse with pool3, then a final 8x back to the input resolution.
        deconv1 = conv_transpose_layer(drop1, 4, 2, conv6_3.get_shape(), pool4.get_shape(), name = 'CONV_TRAN_1')
        fuse1 = tf.add(deconv1, pool4, name = 'fuse_1')
        print(fuse1)
        
        deconv2 = conv_transpose_layer(fuse1, 4, 2, pool4.get_shape(), pool3.get_shape(), name = 'CONV_TRAN_2')
        fuse2 = tf.add(deconv2, pool3, name = 'fuse_2')
        print(fuse2)
        
        deconv3 = conv_transpose_layer(fuse2, 16, 8, pool3.get_shape(), origin_image.get_shape(), name = 'CONV_TRAN_3')
        print(deconv3)
        
        # Per-pixel prediction: argmax over the 4 class channels.
        y_result = tf.argmax(deconv3, axis = 3, name = 'y_result')
        print(y_result)

        y_result = tf.cast(y_result, tf.int32)
        with tf.name_scope('input'):
            tf.summary.image('input', X, BatchSize)
            
        with tf.name_scope('output'):
            # Spread the 4 class ids over distinct gray levels (0, 63, 144, 243) for viewing.
            y_paint = tf.cast(y_result, tf.uint8)
            y_paint = (y_paint * (y_paint + 6)) * 9
            tf.summary.image('output', tf.reshape(y_paint, [-1, 224, 224, 1]), BatchSize)

        
        with tf.name_scope('summaries'):
            # BN's moving-average updates live in UPDATE_OPS and must run with every training step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits = deconv3,
                                                                                              labels = y,
                                                                                              name = "cross_entropy"))
                learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, decay_step, decay_rate, staircase = True)
                train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
                #train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
                correct_prediction = tf.equal(y_result, y)
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float', name = 'accuracy'))
                tf.summary.scalar("loss", cross_entropy)
                tf.summary.scalar("accuracy", accuracy)
        with tf.name_scope('learning_rate'):
            tf.summary.scalar("learning_rate", learning_rate)
        
        
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord = coord)
        
        merge_summary = tf.summary.merge_all()
        summary__train_writer = tf.summary.FileWriter("./logs/train" + '_rate:' + str(start_learning_rate), sess.graph)
        summary_val_writer = tf.summary.FileWriter("./logs/test" + '_rate:' + str(start_learning_rate))
        
        saver = tf.train.Saver()
#         saver.restore(sess, model_path + 'Model_rate_1e-5__Step_00020000')
        
        try:
            batch_index = 1
            while not coord.should_stop():
                sess.run(train_step, feed_dict = {is_training: True, keep_prob: 0.5, global_step: batch_index})
                if batch_index % DisplayStep == 0:
                    # Note: fetching train_step here performs one extra update alongside the summaries.
                    summary_train, acc_train, loss_train, _ = sess.run([merge_summary, accuracy, cross_entropy, train_step], feed_dict = {is_training: True, keep_prob: 0.5, global_step: batch_index})
                    summary__train_writer.add_summary(summary_train, batch_index)
                    print(str(batch_index) + ' train:' + '  ' + str(acc_train) + ' ' + str(loss_train), end = '   ')
                    summary_val, acc_val, loss_val = sess.run([merge_summary, accuracy, cross_entropy], feed_dict = {is_training: False, keep_prob: 1.0, global_step: batch_index})
                    summary_val_writer.add_summary(summary_val, batch_index)
                    print('  val: ' + '  ' + str(acc_val) + ' ' + str(loss_val))
                if batch_index % ModelSaverStep == 0:
                    save_path = saver.save(sess, model_path + '/newModel/Model_rate_1e-5__Step_{:08d}'.format(batch_index))

                batch_index += 1

#                 if batch_index > 1000:
#                     break;
#                 for i in range(BatchSize):
#                     plt.imshow(ans[0], cmap = 'gray')
#                     plt.show()
                
        except tf.errors.OutOfRangeError:
            # With num_epochs = None, string_input_producer cycles forever, so in practice
            # a run ends via KeyboardInterrupt (caught in main) rather than here.
            print("OutOfRangeError!")

        coord.request_stop()
        coord.join(threads)
        sess.close()


# In[ ]:


def main():
    rate = 1e-5
    while True:
        print('-----------------------------------------------------')
        print('Batch: 16      learning_rate:', rate)
        try:
            Network(16, rate)
        except KeyboardInterrupt:
            # Ctrl+C ends the current run; the loop then restarts training with a 3x smaller rate.
            pass
        rate /= 3

if __name__ == '__main__':
    main()


Source: blog.csdn.net/qq_40861916/article/details/100589613