FCN 图像语义分割训练 Sift-flow Dataset + Batch Normalization
前言
在上一篇博客中,我实现了一个 FCN,训练了近 30 个小时才勉强看到效果。
前面说过本人很懒,怎么可能愿意等那么久呢?于是想着加个 BN,百度了一下,发现没有人这么做过。那我肯定说干就干,结果发现原本 30 小时的训练量,半小时就能完成……
实验室同学提醒我说,FCN 出来的时候 BN 还没面世。一查才发现,BN 的论文于 2015.3.2 发表,而 FCN 的论文于 2015.8.8 发表,两个神器的发表时间居然只差了几个月。
加了 BN 层之后,收敛更快、效果更好。为了方便调参,我还加了 tf.train.exponential_decay,让学习率自动衰减。
效果和上一篇博客几乎一样,可以参考一下。
网络图如下:
代码
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
import numpy as np
import os
import shutil

# Restrict TensorFlow to GPU 2; this must happen before tensorflow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
# Start every run with an empty TensorBoard log directory. shutil.rmtree does
# not spawn a shell and, with ignore_errors=True, stays silent when ./logs
# does not exist yet (the previous `rm -r logs` printed an error then).
shutil.rmtree("logs", ignore_errors=True)
import tensorflow as tf
import matplotlib.pyplot as plt

# TFRecord files for the SiftFlow geometric-label splits.
trainPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/train.tfrecords'
testPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/test.tfrecords'
valPath = '/home/winsoul/disk/Segmentation/SiftFlow/data/GeoLabels/tfrecords/val.tfrecords'
# Root directory for model checkpoints.
model_path = '/home/winsoul/disk/Segmentation/SiftFlow/FCN_modelUseBN/model_backup/'
# Interval, in training steps, between train/val metric logging.
DisplayStep = 25
# Interval, in training steps, between checkpoint saves.
ModelSaverStep = 2500
# tf.train.exponential_decay schedule: multiply the learning rate by
# decay_rate once every decay_step steps (staircase).
decay_step = 500
decay_rate = 0.995
# In[ ]:
def read_tfrecords(TFRecordsPath, height=224, width=224):
    """Build queue-based ops that read one (image, label) example from a TFRecord file.

    Each record must contain 'image' (raw float32 bytes), 'label' (raw uint8
    bytes) and 'name' (bytes) features. 'name' is parsed but not returned.
    The ops use a string_input_producer queue, so queue runners must be
    started (tf.train.start_queue_runners) before the tensors can be evaluated.

    Args:
        TFRecordsPath: path to the .tfrecords file.
        height: image/label height used to reshape the decoded bytes.
        width: image/label width used to reshape the decoded bytes.

    Returns:
        (image, label): a float32 tensor of shape [height, width, 3] and a
        uint8 tensor of shape [height, width].
    """
    # No tf.Session is needed to *build* these ops; the previous throwaway
    # session was never used and only allocated device resources.
    feature = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
        'name': tf.FixedLenFeature([], tf.string),
    }
    filename_queue = tf.train.string_input_producer([TFRecordsPath])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [height, width, 3])
    label = tf.decode_raw(features['label'], tf.uint8)
    label = tf.reshape(label, [height, width])
    return image, label
# In[ ]:
def conv_layer(X, k, s, channels_in, channels_out, name='CONV', padding='SAME', is_training=True):
    """Convolution (k x k, stride s) plus bias, then batch normalization, then ReLU.

    Args:
        X: NHWC input tensor with `channels_in` channels.
        k: kernel size.
        s: stride in both spatial dimensions.
        channels_in: number of input channels.
        channels_out: number of output channels.
        name: name scope for the layer.
        padding: padding mode passed to tf.nn.conv2d.
        is_training: Python bool or boolean tensor forwarded as the `training`
            argument of tf.layers.batch_normalization.

    Returns:
        The ReLU activation of the normalized convolution output.
    """
    kernel_shape = [k, k, channels_in, channels_out]
    stride = [1, s, s, 1]
    with tf.name_scope(name):
        # Weights are created before the bias to keep variable names stable.
        weights = tf.Variable(tf.truncated_normal(kernel_shape, stddev=0.1))
        bias = tf.Variable(tf.constant(0.1, shape=[channels_out]))
        pre_activation = tf.nn.conv2d(X, weights, strides=stride, padding=padding) + bias
        normalized = tf.layers.batch_normalization(pre_activation, training=is_training)
        return tf.nn.relu(normalized)
def pool_layer(X, k, s, strr='SAME', pool_type='MAX', name='pool'):
    """Spatial pooling with a k x k window and stride s.

    Args:
        X: NHWC input tensor.
        k: window size.
        s: stride.
        strr: padding mode ('SAME' or 'VALID').
        pool_type: 'MAX' for max pooling; any other value selects average pooling.
        name: name scope, also used as the pooling op's name.

    Returns:
        The pooled tensor.
    """
    pool_fn = tf.nn.max_pool if pool_type == 'MAX' else tf.nn.avg_pool
    window = [1, k, k, 1]
    stride = [1, s, s, 1]
    with tf.name_scope(name):
        return pool_fn(X, ksize=window, strides=stride, padding=strr, name=name)
def fc_layer(X, neurons_in, neurons_out, last=False, name='FC'):
    """Fully connected layer: activation(X @ W + b).

    The activation is ReLU when `last == False` (the default) and softmax
    otherwise, which is intended for the final classification layer.

    Args:
        X: 2-D input tensor with `neurons_in` columns.
        neurons_in: input width.
        neurons_out: output width.
        last: selects softmax instead of ReLU.
        name: name scope for the layer.

    Returns:
        The activated output tensor.
    """
    with tf.name_scope(name):
        weights = tf.Variable(tf.truncated_normal([neurons_in, neurons_out], stddev=0.1))
        bias = tf.Variable(tf.constant(0.1, shape=[neurons_out]))
        # Explicit `== False` kept on purpose: it mirrors the original selection rule.
        activation = tf.nn.relu if last == False else tf.nn.softmax
        return activation(tf.matmul(X, weights) + bias)
# In[ ]:
def conv_transpose_layer(X, k, s, input_shape, output_shape, name='CONV_TRAN', padding='SAME'):
    """Upsample X with a learned k x k transposed convolution (stride s) plus a bias.

    Args:
        X: NHWC input tensor.
        k: kernel size.
        s: stride, i.e. the upsampling factor.
        input_shape: static shape of X; its last dimension gives the input channels.
        output_shape: static output shape [batch, height, width, channels]. It
            is passed to tf.nn.conv2d_transpose, so it presumably must be fully
            defined.
        name: name scope for the layer.
        padding: padding mode forwarded to tf.nn.conv2d_transpose (previously
            ignored and hard-coded to "SAME", which is still the default).

    Returns:
        The transposed-convolution output with the per-channel bias added.
    """
    with tf.name_scope(name):
        out_channels = output_shape[3].value
        in_channels = input_shape[3].value
        # conv2d_transpose filters are laid out [k, k, out_channels, in_channels].
        W = tf.Variable(tf.truncated_normal([k, k, out_channels, in_channels], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[out_channels]))
        deconv = tf.nn.conv2d_transpose(X, W, output_shape, strides=[1, s, s, 1], padding=padding)
        # The bias used to be computed and then discarded (`result = deconv`),
        # leaving `b` disconnected from the graph.
        return tf.add(deconv, b)
# In[ ]:
def Network(BatchSize, start_learning_rate):
    """Build the FCN + BatchNorm graph for SiftFlow and train it until stopped.

    Resets the default graph, then builds:
    - queue-based train/val input pipelines from the TFRecord files;
    - a VGG-16-like convolutional encoder with batch normalization;
    - a three-stage transposed-convolution decoder that fuses pool4 and pool3.

    It trains with Adam on an exponentially decaying learning rate. TensorBoard
    summaries are written every DisplayStep steps and checkpoints every
    ModelSaverStep steps.

    Args:
        BatchSize: images per training / validation batch.
        start_learning_rate: initial learning rate for tf.train.exponential_decay.

    The loop runs until the coordinator requests a stop, an OutOfRangeError
    occurs, or the caller interrupts it (e.g. KeyboardInterrupt, which
    propagates). In every case the queue-runner threads are stopped and joined
    and the summary writers are closed.
    """
    tf.reset_default_graph()
    with tf.Session() as sess:
        is_training = tf.placeholder(dtype=tf.bool)   # selects train vs. val batch and BN mode
        keep_prob = tf.placeholder(dtype=tf.float32)  # dropout keep probability
        global_step = tf.placeholder(dtype=tf.int32)  # step fed to the learning-rate schedule
        # Only the static shape of this placeholder is used: it is the output
        # shape of the final transposed convolution.
        origin_image = tf.placeholder(tf.uint8, shape=([BatchSize, 224, 224, 4]))

        # ---- Input pipelines ----------------------------------------------
        image_train, label_train = read_tfrecords(trainPath)
        image_val, label_val = read_tfrecords(valPath)
        image_train_batch, label_train_batch = tf.train.shuffle_batch(
            [image_train, label_train],
            batch_size=BatchSize,
            capacity=BatchSize * 3 + 200,
            min_after_dequeue=BatchSize)
        image_val_batch, label_val_batch = tf.train.shuffle_batch(
            [image_val, label_val],
            batch_size=BatchSize,
            capacity=BatchSize * 3 + 200,
            min_after_dequeue=BatchSize)
        # NOTE(review): both dequeue ops are created outside tf.cond, so every
        # run that needs X or y dequeues one train AND one val batch; tf.cond
        # only chooses which of the two is forwarded.
        image_Batch = tf.cond(is_training, lambda: image_train_batch, lambda: image_val_batch)
        label_Batch = tf.cond(is_training, lambda: label_train_batch, lambda: label_val_batch)
        X = tf.identity(image_Batch)
        y = tf.cast(tf.identity(label_Batch), tf.int32)

        # ---- Encoder ------------------------------------------------------
        # is_training is forwarded so BN uses batch statistics while training
        # and its moving averages for validation (it previously stayed in
        # training mode for validation as well).
        conv1_1 = conv_layer(X, 3, 1, 3, 64, "conv1_1", is_training=is_training)
        conv1_2 = conv_layer(conv1_1, 3, 1, 64, 64, "conv1_2", is_training=is_training)
        pool1 = pool_layer(conv1_2, 2, 2, "SAME", "MAX", "pool1")
        conv2_1 = conv_layer(pool1, 3, 1, 64, 128, "conv2_1", is_training=is_training)
        conv2_2 = conv_layer(conv2_1, 3, 1, 128, 128, "conv2_2", is_training=is_training)
        pool2 = pool_layer(conv2_2, 2, 2, "SAME", "MAX", 'pool2')
        conv3_1 = conv_layer(pool2, 3, 1, 128, 256, "conv3_1", is_training=is_training)
        conv3_2 = conv_layer(conv3_1, 3, 1, 256, 256, "conv3_2", is_training=is_training)
        conv3_3 = conv_layer(conv3_2, 3, 1, 256, 256, "conv3_3", is_training=is_training)
        pool3 = pool_layer(conv3_3, 2, 2, "SAME", "MAX", 'pool3')
        conv4_1 = conv_layer(pool3, 3, 1, 256, 512, "conv4_1", is_training=is_training)
        conv4_2 = conv_layer(conv4_1, 3, 1, 512, 512, "conv4_2", is_training=is_training)
        conv4_3 = conv_layer(conv4_2, 3, 1, 512, 512, "conv4_3", is_training=is_training)
        pool4 = pool_layer(conv4_3, 2, 2, "SAME", "MAX", 'pool4')
        print(pool4)
        conv5_1 = conv_layer(pool4, 3, 1, 512, 512, "conv5_1", is_training=is_training)
        conv5_2 = conv_layer(conv5_1, 3, 1, 512, 512, "conv5_2", is_training=is_training)
        conv5_3 = conv_layer(conv5_2, 3, 1, 512, 512, "conv5_3", is_training=is_training)
        pool5 = pool_layer(conv5_3, 2, 2, "SAME", "MAX", 'pool5')
        print(pool5)
        conv6_1 = conv_layer(pool5, 7, 1, 512, 1024, "conv6_1", is_training=is_training)
        conv6_2 = conv_layer(conv6_1, 1, 1, 1024, 512, "conv6_2", is_training=is_training)
        conv6_3 = conv_layer(conv6_2, 1, 1, 512, 4, "conv6_3", is_training=is_training)
        drop1 = tf.nn.dropout(conv6_3, keep_prob)
        print(drop1)

        # ---- Decoder: upsample and fuse with pool4 / pool3 ----------------
        deconv1 = conv_transpose_layer(drop1, 4, 2, conv6_3.get_shape(), pool4.get_shape(), name='CONV_TRAN_1')
        fuse1 = tf.add(deconv1, pool4, name='fuse_1')
        print(fuse1)
        deconv2 = conv_transpose_layer(fuse1, 4, 2, pool4.get_shape(), pool3.get_shape(), name='CONV_TRAN_2')
        fuse2 = tf.add(deconv2, pool3, name='fuse_2')
        print(fuse2)
        deconv3 = conv_transpose_layer(fuse2, 16, 8, pool3.get_shape(), origin_image.get_shape(), name='CONV_TRAN_3')
        print(deconv3)
        y_result = tf.argmax(deconv3, axis=3, name='y_result')
        print(y_result)
        y_result = tf.cast(y_result, tf.int32)

        # ---- Image summaries ----------------------------------------------
        with tf.name_scope('input'):
            tf.summary.image('input', X, BatchSize)
        with tf.name_scope('output'):
            # Spread the small class ids (argmax over 4 channels) across the
            # uint8 range so the predicted map is visible in TensorBoard.
            y_paint = tf.cast(y_result, tf.uint8)
            y_paint = (y_paint * (y_paint + 6)) * 9
            tf.summary.image('output', tf.reshape(y_paint, [-1, 224, 224, 1]), BatchSize)

        # ---- Loss, optimizer, metrics --------------------------------------
        with tf.name_scope('summaries'):
            # Make the loss depend on the BN moving-average updates so that
            # running train_step also refreshes them.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=deconv3, labels=y, name="cross_entropy"))
            learning_rate = tf.train.exponential_decay(
                start_learning_rate, global_step, decay_step, decay_rate, staircase=True)
            train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
            correct_prediction = tf.equal(y_result, y)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float', name='accuracy'))
            tf.summary.scalar("loss", cross_entropy)
            tf.summary.scalar("accuracy", accuracy)
        with tf.name_scope('learning_rate'):
            tf.summary.scalar("learning_rate", learning_rate)

        # ---- Session setup -------------------------------------------------
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        merge_summary = tf.summary.merge_all()
        summary__train_writer = tf.summary.FileWriter("./logs/train" + '_rate:' + str(start_learning_rate), sess.graph)
        summary_val_writer = tf.summary.FileWriter("./logs/test" + '_rate:' + str(start_learning_rate))
        saver = tf.train.Saver()
        # To resume training, call saver.restore(sess, <checkpoint path>) here.
        checkpoint_dir = os.path.join(model_path, 'newModel')

        # ---- Training loop -------------------------------------------------
        try:
            batch_index = 1
            while not coord.should_stop():
                train_feed = {is_training: True, keep_prob: 0.5, global_step: batch_index}
                if batch_index % DisplayStep == 0:
                    # Train and report in a single run; previously this step
                    # ran train_step twice (once bare, once with the summaries).
                    summary_train, acc_train, loss_train, _ = sess.run(
                        [merge_summary, accuracy, cross_entropy, train_step], feed_dict=train_feed)
                    summary__train_writer.add_summary(summary_train, batch_index)
                    print(str(batch_index) + ' train:' + ' ' + str(acc_train) + ' ' + str(loss_train), end=' ')
                    summary_val, acc_val, loss_val = sess.run(
                        [merge_summary, accuracy, cross_entropy],
                        feed_dict={is_training: False, keep_prob: 1.0, global_step: batch_index})
                    summary_val_writer.add_summary(summary_val, batch_index)
                    print(' val: ' + ' ' + str(acc_val) + ' ' + str(loss_val))
                else:
                    sess.run(train_step, feed_dict=train_feed)
                if batch_index % ModelSaverStep == 0:
                    # Saver.save needs the parent directory to exist. The file
                    # name records the actual starting rate instead of a
                    # hard-coded "1e-5".
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    saver.save(sess, os.path.join(
                        checkpoint_dir,
                        'Model_rate_{:g}__Step_{:08d}'.format(start_learning_rate, batch_index)))
                batch_index += 1
        except tf.errors.OutOfRangeError:
            print("OutofRangeError!")
        finally:
            # Runs on normal exit, OutOfRangeError and KeyboardInterrupt alike.
            coord.request_stop()
            coord.join(threads)
            summary__train_writer.close()
            summary_val_writer.close()
# In[ ]:
def main():
    """Run training sessions forever, dividing the learning rate by 3 after each one.

    Every session uses batch size 16, and the first one starts at a learning
    rate of 1e-5. A KeyboardInterrupt raised while a session is running ends
    only that session; the sweep then continues with the next, smaller rate.
    """
    batch_size = 16
    learning_rate = 1e-5
    separator = '-----------------------------------------------------'
    while True:
        print(separator)
        print('Batch: 16 learning_rate:', learning_rate)
        try:
            Network(batch_size, learning_rate)
        except KeyboardInterrupt:
            # Ctrl+C skips to the next learning rate instead of exiting.
            pass
        learning_rate = learning_rate / 3


if __name__ == '__main__':
    main()