tensorflow tfrecord文件生成,网络输入管道

tensorflow tfrecord文件生成,网络输入管道

标签(空格分隔): tensorflow 源码


在医学图像中,不像自然图像那样是规整的3通道8位数据,不同的医学影像有不同的医学存储格式,以本小硕的课题来说,医学图像数据类型为float32。之前为了保证数据的原始性,一直不敢存储为png、bmp那样的数据格式,而是存储为numpy的npz格式。

但是,对于tensorflow来说,如果采用npz存储的话,需要一次性将数据全部读入内存,这样一是读取速度特别慢;二是浪费内存。最终,本小硕还是试图转成tfrecord标准文件,采用tensorflow自带的数据流图。

转换代码:

import os
import sys
import numpy as np
import math
import tensorflow as tf
#import build_data

def covert_bin2tfrecord(data_dir,num_shards,save_path):
    """Convert ``data.npy`` / ``label.npy`` into sharded TFRecord files.

    One ``tf.train.Example`` is written per sample. It holds the raw bytes of
    ``X[i]`` and ``Y[i]`` (in each array's own dtype, so a reader must decode
    with the matching dtype) together with the height, width and channel
    count taken from the shape of ``X``.

    Args:
        data_dir: Directory containing ``data.npy`` and ``label.npy``. Its
            last path component is used as the prefix of the shard filenames.
        num_shards: Number of output files to split the samples across.
        save_path: Directory in which the ``*.tfrecord`` shards are written.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
    """
    if num_shards < 1:
        raise ValueError(
            'num_shards must be a positive integer, got %r' % (num_shards,))

    # Load the raw arrays.
    X = np.load(os.path.join(data_dir, 'data.npy'))
    Y = np.load(os.path.join(data_dir, 'label.npy'))

    num_slices = X.shape[0]
    # Axes 2 and 3 are used as height and width, so axis 1 is taken to be the
    # channel axis (assumes X is laid out as (N, C, H, W) -- TODO confirm).
    # Recording X.shape[1] keeps the metadata in sync with the stored bytes
    # instead of relying on a hard-coded channel count.
    channels, height, width = X.shape[1], X.shape[2], X.shape[3]
    num_per_shard = int(math.ceil(num_slices / float(num_shards)))

    # normpath drops a trailing separator so 'some/dir/' still yields 'dir'
    # rather than an empty filename prefix.
    prefix = os.path.basename(os.path.normpath(data_dir))

    for shard_id in range(num_shards):
        output_filename = os.path.join(
            save_path,
            '%s-%05d-of-%05d.tfrecord' % (prefix, shard_id, num_shards))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_slices)
            for i in range(start_idx, end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, num_slices, shard_id))
                sys.stdout.flush()
                # tobytes() replaces the deprecated tostring() alias.
                image_data = tf.compat.as_bytes(X[i, ...].tobytes())
                gt_data = tf.compat.as_bytes(Y[i, ...].tobytes())
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
                    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                    'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                    'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[channels])),
                    'image/segmentation/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[gt_data])),
                    # BytesList.value is a repeated field and must be given a
                    # list; a bare b'png' would be iterated byte by byte.
                    # NOTE(review): the mask payload above is raw array bytes,
                    # not PNG-encoded data; the tag is kept for compatibility.
                    'image/segmentation/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'png'])),
                }))
                tfrecord_writer.write(example.SerializeToString())
            sys.stdout.write('\n')
            sys.stdout.flush()

if __name__ == '__main__':
    # One (data_dir, num_shards, save_path) entry per dataset split:
    # first the training set (5 shards), then the test set (1 shard).
    conversion_jobs = (
        ('/', 5, '/'),
        ('/', 1, '/'),
    )
    for source_dir, shard_count, target_dir in conversion_jobs:
        covert_bin2tfrecord(source_dir, shard_count, target_dir)

保存为tfrecord文件后,为了以防万一,我们还是要可视化一下数据是否改变:

import tensorflow as tf
import numpy as np
from skimage import io
#from skimage import io
from glob import glob
import os
# Configure GPU visibility through the CUDA environment variables:
# enumerate devices in PCI bus order and expose only device index 1.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
def read_tfrecord(tfrecords_filename, image_shape=(6, 320, 320), gt_shape=None):
    """Build graph ops that read one (image, mask) pair from TFRecord files.

    Uses the TF 1.x queue-based input pipeline, so queue runners must be
    started in the session before the returned tensors are evaluated.

    Args:
        tfrecords_filename: A single filename or a list/tuple of filenames.
        image_shape: Shape the decoded float32 image is reshaped to. The
            default matches the previously hard-coded ``[6, 320, 320]``.
        gt_shape: Optional shape for the decoded uint8 mask. When ``None``
            (the default) the mask is returned as a flat 1-D tensor, as
            before.

    Returns:
        A tuple ``(image, gt_mask)`` of tensors: ``image`` is float32 with
        shape ``image_shape``; ``gt_mask`` is uint8.
    """
    # Accept a lone filename as well as a collection of filenames.
    if not isinstance(tfrecords_filename, (list, tuple)):
        tfrecords_filename = [tfrecords_filename]
    filename_queue = tf.train.string_input_producer(
        list(tfrecords_filename))

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/segmentation/encoded': tf.FixedLenFeature([], tf.string),
        })

    # The payloads are raw array bytes; decode them back into typed tensors.
    image = tf.decode_raw(features['image/encoded'], tf.float32)
    gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
    image = tf.reshape(image, list(image_shape))
    if gt_shape is not None:
        gt_mask = tf.reshape(gt_mask, list(gt_shape))
    return image, gt_mask


if __name__=='__main__':
    files=glob('/train*')
    with tf.Session() as sess:
        # Build the file queue / record reader / parsing ops via the shared
        # helper instead of duplicating the same pipeline inline.
        image, gt_mask = read_tfrecord(files)
        # shuffle_batch needs statically known shapes, so give the flat
        # uint8 mask its 2-D shape here.
        gt_mask = tf.reshape(gt_mask, (320, 320))

        # Batching queue.
        image_batch, gt_batch = tf.train.shuffle_batch(
            [image, gt_mask], batch_size=256, capacity=30,
            min_after_dequeue=20, num_threads=1)

        # Initialise the graph's local and global variables.
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Thread management: start the queue runners, and always stop and
        # join them, even if fetching or saving the batch fails.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            ib, gb = sess.run([image_batch, gt_batch])
            print(ib.shape)
            print(gb.shape)

            # Tile samples down the rows and channels across the columns,
            # one 320x320 tile per (sample, channel) pair.
            # NOTE(review): the canvas has room for 256 samples (81920 rows)
            # but only the first 64 samples and the first 4 channels are
            # drawn; the rest of the canvas stays zero. Confirm this is
            # intended.
            data = np.zeros((81920, 1280), dtype=np.float32)
            for i in range(64):
                for j in range(4):
                    data[i*320:(i+1)*320, j*320:(j+1)*320] = ib[i, j, ...]

            # Visualise.
            io.imsave('vis_tfrecord.png', data)
        finally:
            coord.request_stop()
            coord.join(threads)

例子就不方便展示了

猜你喜欢

转载自blog.csdn.net/charel_chen/article/details/80365491