tensorflow tfrecord文件生成,网络输入管道
标签(空格分隔): tensorflow 源码
在医学图像中,不像自然图像那样是规整的3通道8位数据,不同的医学影像有不同的医学存储格式,以本小硕的课题来说,医学图像数据类型为float32。之前为了保证数据的原始性,一直不敢存储为png、bmp那样的数据格式,而是存储为numpy的npz格式。
但是,对于tensorflow来说,如果采用npz存储的话,需要一次性将数据全部读入内存,这样一是读取速度特别慢;二是浪费内存。最终,本小硕还是试图转成tfrecord标准文件,采用tensorflow自带的数据流图。
转换代码:
import os
import sys
import numpy as np
import math
import tensorflow as tf
#import build_data
def covert_bin2tfrecord(data_dir, num_shards, save_path):
    """Convert raw .npy image/label arrays into sharded TFRecord files.

    Args:
        data_dir: directory containing 'data.npy' (float32 images,
            presumably shaped (N, C, H, W) -- TODO confirm against the
            dataset) and 'label.npy' (per-slice segmentation masks).
        num_shards: number of output .tfrecord shard files to produce.
        save_path: directory where the shards are written; each shard is
            named '<dataset>-NNNNN-of-NNNNN.tfrecord'.
    """
    # Load the full arrays once; individual slices are serialized per record.
    X = np.load(os.path.join(data_dir, 'data.npy'))
    Y = np.load(os.path.join(data_dir, 'label.npy'))
    num_slices = X.shape[0]
    num_per_shard = int(math.ceil(num_slices / float(num_shards)))
    # Spatial dims are invariant across slices; hoist out of the loops.
    height, width = X.shape[2], X.shape[3]
    for shard_id in range(num_shards):  # range: Python 3 compatible
        output_filename = os.path.join(
            save_path,
            '%s-%05d-of-%05d.tfrecord' % (data_dir.split('/')[-1],
                                          shard_id, num_shards))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_slices)
            for i in range(start_idx, end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, num_slices, shard_id))
                sys.stdout.flush()
                # Raw little-endian bytes of the float32 image slice and its
                # mask; tobytes() replaces the deprecated tostring().
                image_data = tf.compat.as_bytes(X[i, ...].tobytes())
                gt_data = tf.compat.as_bytes(Y[i, ...].tobytes())
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
                    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                    'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                    # FIX: record the actual channel count instead of a
                    # hard-coded 4 (the reader reshapes to 6 channels).
                    'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[X.shape[1]])),
                    'image/segmentation/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[gt_data])),
                    # FIX: BytesList.value must be a *list* of bytes objects;
                    # a bare b'png' is iterated as ints and raises a TypeError.
                    'image/segmentation/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'png'])),
                }))
                tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()
if __name__ == '__main__':
    # Placeholder paths -- point these at the real dataset directories.
    covert_bin2tfrecord('/', 5, '/')  # training set: 5 shards
    covert_bin2tfrecord('/', 1, '/')  # test set: single shard
保存为tfrecord文件后,为了以防万一,我们还是要可视化一下数据是否改变:
import tensorflow as tf
import numpy as np
from skimage import io
#from skimage import io
from glob import glob
import os
# Enumerate GPUs in PCI bus order and expose only device 1 to TensorFlow,
# so the visualization run does not grab every GPU on the machine.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def read_tfrecord(tfrecords_filename, image_shape=(6, 320, 320),
                  mask_shape=(320, 320)):
    """Build a queue-based TFRecord input pipeline (TF 1.x API).

    Args:
        tfrecords_filename: one path or a list of paths to .tfrecord shards.
        image_shape: static (C, H, W) to reshape the decoded float32 image
            to; defaults match the original hard-coded (6, 320, 320).
        mask_shape: static (H, W) of the uint8 segmentation mask.

    Returns:
        (image, gt_mask) tensors yielding one decoded example per dequeue.
    """
    if not isinstance(tfrecords_filename, list):
        tfrecords_filename = [tfrecords_filename]
    # Queue of shard filenames feeding a single-example record reader.
    filename_queue = tf.train.string_input_producer(tfrecords_filename)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/segmentation/encoded': tf.FixedLenFeature([], tf.string),
        })
    # Decode the raw byte strings back to their original dtypes.
    image = tf.decode_raw(features['image/encoded'], tf.float32)
    gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
    image = tf.reshape(image, image_shape)
    # FIX: the mask was previously returned as a flat vector; restore its
    # 2-D shape for consistency with the inline pipeline below.
    gt_mask = tf.reshape(gt_mask, mask_shape)
    return image, gt_mask
if __name__ == '__main__':
    files = glob('/train*')
    with tf.Session() as sess:
        # Build the filename queue and a single-example record reader.
        filename_queue = tf.train.string_input_producer(files)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/encoded': tf.FixedLenFeature([], tf.string),
                'image/segmentation/encoded': tf.FixedLenFeature([], tf.string)
            })
        # Decode the raw byte strings back to float32 image / uint8 mask
        # and restore their static shapes.
        image = tf.decode_raw(features['image/encoded'], tf.float32)
        image = tf.reshape(image, (6, 320, 320))
        gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
        gt_mask = tf.reshape(gt_mask, (320, 320))
        # Shuffled batching of decoded examples via background queue runners.
        image_batch, gt_batch = tf.train.shuffle_batch(
            [image, gt_mask], batch_size=256, capacity=30,
            min_after_dequeue=20, num_threads=1)
        # Initialize graph variables (string_input_producer uses local vars).
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Start the threads that fill the input queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        ib, gb = sess.run([image_batch, gt_batch])
        print(ib.shape)
        print(gb.shape)
        # Tile the first 64 images (first 4 of the 6 channels each) into one
        # mosaic of 64 rows x 4 columns of 320x320 tiles.
        num_rows, num_cols, tile = 64, 4, 320
        # FIX: the buffer was (81920, 1280) but the loops only ever write
        # 64*320 = 20480 rows, leaving 3/4 of the saved PNG black; size the
        # canvas to exactly what gets filled.
        data = np.zeros((num_rows * tile, num_cols * tile), dtype=np.float32)
        for i in range(num_rows):       # range: Python 3 compatible
            for j in range(num_cols):   # NOTE: shows only 4 of 6 channels
                data[i * tile:(i + 1) * tile,
                     j * tile:(j + 1) * tile] = ib[i, j, ...]
        # Sanity-check visualization of the round-tripped data.
        io.imsave('vis_tfrecord.png', data)
        coord.request_stop()
        coord.join(threads)
例子就不方便展示了