注意:
- 我自己是用来训练医学图像生成新的医学图像的,所以我的label都是1(或者其他数值),因为在GAN训练的时候完全用不到,只有分类的问题才能用到。
- 如果训练分类问题下文中的代码就没用了,不过可以参考这个链接。
- 我的图片是黑白的,所以channel为1,并且图片保存的格式为png的,所以大家理解的时候请注意
image=tf.image.decode_png(image_contents,channels=1)
。
- 代码的输入是png图片,通道数为1,输出为batch_size的
tensor
(shape=[batch_size, image_W, image_H, 1]),但是绝对不能用到feed_dict里面,因为feed_dict只接收numpy格式的数据格式
,所以在接下来的博客中,本博主要在上面下功夫!
tensorflow中 tf.train.slice_input_producer
函数可以参考这个链接,和这个链接讲的非常详细。
- 初始化工作相当重要
输出其中一个样例:
"""
Created on Thu Jul 19 15:40:11 2018
E-mail: [email protected]
@author: DidiLv
"""
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
file_dir='D:\\CT_data\\Data_preprocessing\\'
def get_files(file_dir):
lung_img = [];
label_lung_img = [];
for file in os.listdir(file_dir):
lung_img.append( file_dir + file)
label_lung_img.append(1)
image_list = np.hstack((lung_img))
label_list = np.hstack((label_lung_img))
temp = np.array([lung_img, label_lung_img]).T
np.random.shuffle(temp)
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(i) for i in label_list]
return image_list, label_list
def get_batch(image,label,batch_size):
image_W, image_H = 221, 181
image=tf.cast(image,tf.string)
label=tf.cast(label,tf.int32)
epoch_num = 50
input_queue=tf.train.slice_input_producer([image,label], num_epochs=epoch_num)
label=input_queue[1]
image_contents=tf.read_file(input_queue[0])
image=tf.image.decode_png(image_contents,channels=1)
image=tf.image.resize_image_with_crop_or_pad(image,image_W,image_H)
image=tf.image.per_image_standardization(image)
min_after_dequeue=1000
capacity=min_after_dequeue+3*batch_size
image_batch,label_batch=tf.train.shuffle_batch([image,label],batch_size=batch_size,num_threads=64,capacity=capacity,min_after_dequeue=min_after_dequeue)
image_batch = tf.reshape(image_batch,[batch_size,image_W,image_H,1])
image_batch=tf.cast(image_batch,np.float32)
return image_batch, label_batch
if __name__ == "__main__":
image_list, label_list = get_files(file_dir)
image_batch, label_batch = get_batch(image_list, label_list, 64)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
i = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
img, label = sess.run([image_batch, label_batch])
for j in np.arange(64):
print('label: %d' %label[j])
plt.imshow(img[j,:,:,0])
plt.show()
i+=1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
print('-----------')
coord.join(threads)