【TFRecord】Tensorflow默认标准数据格式

Tensorflow默认标准数据格式TFRecord学习

 简介

       在我们的工程项目中,需要训练的数据集通常会以不同的格式出现,有时候甚至是混合模式。为了方便我们使用,将文件不管其原始格式,都转换成统一的格式是很好的方法。
       
       Tensorflow默认的标准数据格式是TFRecord。TFRecord文件就是一个简单的二进制文件,包括序列化的输入世俗据。序列化是基于协议缓冲区( protobuf )的,会使用一种独立于任意平台或者语言的描述数据结构的机制转换数据,供存储使用。
       
       在我们的设置下,使用TFRecord相比直接采用原始数据文件有很多优势。这个统一的格式能够给出一种简洁的方法去组织输入数据,使一个输入实例的所有相关属性存放在一起,避免大量文件目录的泛滥。
       
       TFRecord文件可以有非常快的处理速度。所有的数据被存放在内存的一块中。而不是分割的每个文件,将原本需要从内存中读取数据的时间节省下来。
       
       除此之外,Tensorflow提供了有关TFRecord的很多实现及优化工具,使得TFRecord可以用座多线程输入管道的一部分。

 写入数据

       首先,我们将一个输入文件写为TFRecord格式,这样就可以对这些文件进行处理。在下面的示例,我们将把MNIST图像转换为该格式,类似的方法也可以用在其他类型的数据上。

       首先下载数据集,导入相关的包:

	from __future__ import print_function
	import os
	import tensorflow as tf
	from tensorflow.contrib.learn.python.learn.datasets import mnist
	import numpy as np
	
	save_dir = 'c:/tmp/data'
	
	# Download data to save_dir
	data_sets = mnist.read_data_sets(save_dir, dtype=tf.uint8, reshape=False, validation_size=1000)

       下载完的数据包括训练、测试和验证图像,每个都是在一个分隔中。遍历每个分隔,将样本转化为合适的格式并使用 TFRecordWriter() 写到磁盘上

	data_splits = ["train", "test", "validation"]
	for d in range(len(data_splits)):
	    print("saving" + data_splits[d])
	    data_set = data_sets[d]
	
	    # 实例化TFRecordWriter对象
	    filename = os.path.join(save_dir, data_splits[d] + '.tfrecords')
	    writer = tf.python_io.TFRecordWriter(filename)
	
	    for index in range(data_set.images.shape[0]):
	        # 遍历图像,将其从numpy数组转换成一个字节字符串
	        image = data_set.images[index].tostring()
	        # 存放数据,一个Example对象包含一个Fetures对象,Fetures对象包含从属性名到一个Feture的映射
	        # 特征可以包含Int64List、BytesList或Floatlist
	        example = tf.train.Example(features=tf.train.Features(feature={
    
    
	            'height' : tf.train.Feature(int64_list=tf.train.Int64List(value=[data_set.images.shape[1]])),
	            'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[data_set.images.shape[2]])),
	            'depth' : tf.train.Feature(int64_list=tf.train.Int64List(value=[data_set.images.shape[3]])),
	            'label' : tf.train.Feature(int64_list=tf.train.Int64List(value=[int(data_set.labels[index])])),
	            'image_raw' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
	        }))
	
	        writer.write(example.SerializeToString())
	    writer.close()

       

 读取数据

       想要读取保存的数据,这里可以使用tf.python_io.tf_recond_iterator 迭代器从一个 TFRecord 文件中读取记录:

		filename = os.path.join(save_dir, 'train.tfrecords')
		record_iterator = tf.python_io.tf_record_iterator(filename)
		seralized_img_example = next(record_iterator)

       
       为了恢复当保存图像到TFRecord时使用需要的结构,我们对这个字节字符串进行解析,就能够获取之前存放的所有属性:

		example = tf.train.Example()
		example.ParseFromString(seralized_img_example)
		image = example.features.feature['image_raw'].bytes_list.value
		label = example.features.feature['label'].int64_list.value[0]
		width = example.features.feature['width'].int64_list.value[0]
		height = example.features.feature['height'].int64_list.value[0]

       我们图像也存放为一个字节字符串,所以我们可以将其转换还原成一个Numpy数组,将其重新变形( 28, 28, 1 )

		img_flat = np.fromstring(image[0], dtype=np.uint8)
		img_reshaped = img_flat.reshape((height, width, -1))

       
       这个基本的例子展示了如何使用TFRecord以及如何读写他们的方法。实践中,一般会将TFRecord读到一个预获取的队列中作为多线程过程的一部分。

       

       本文示例参考《TensorFlow学习指南——深度学习系统构建详解》第八章第二节。

       欢迎各位大佬交流讨论!

猜你喜欢

转载自blog.csdn.net/weixin_42721167/article/details/112750572