tensorflow中tfrecords使用介绍

  这篇文章主要讲一下如何用Tensorflow中的标准数据读取方式简单的实现对自己数据的读取操作.

主要分为以下两个步骤:(1)将自己的数据集转化为 xx.tfrecords的形式;(2):在自己的程序中读取并使用.tfrecords进行操作.

数据集转换:为了便于讲解,我们简单制作了一个数据,如下图所示:


程序:

[python]  view plain  copy
  1. import tensorflow as tf  
  2. import numpy as np  
  3. import os  
  4. from PIL import Image  
  5. def _int64_feature(value):  
  6.   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  7.   
  8. def _bytes_feature(value):  
  9.   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  10.   
  11.   
  12. def img_to_tfrecord(data_path):  
  13.     rows = 256  
  14.     cols = 256  
  15.     depth = 3  
  16.     writer = tf.python_io.TFRecordWriter('test.tfrecords')  
  17.     labelfile=open("random.txt")  
  18.     lines=labelfile.readlines()  
  19.     for line in lines:  
  20.         #print line  
  21.         img_name = line.split(" ")[0]#name  
  22.         label = line.split(" ")[1]#label  
  23.         img_path = data_path+img_name  
  24.         img = Image.open(img_path)  
  25.         img = img.resize((rows,cols))  
  26.         #img_raw = img.tostring()      
  27.         img_raw = img.tobytes()   
  28.         example = tf.train.Example(features = tf.train.Features(feature = {  
  29.                             'height': _int64_feature(rows),  
  30.                            'weight': _int64_feature(cols),  
  31.                             'depth': _int64_feature(depth),  
  32.                         'image_raw': _bytes_feature(img_raw),  
  33.                 'label': _bytes_feature(label)}))  
  34.                   
  35.             writer.write(example.SerializeToString())      
  36.     writer.close()   
  37.   
  38.   
  39.   
  40. if __name__ == '__main__':  
  41.     current_dir = os.getcwd()      
  42.     data_path = current_dir + '/data/'      
  43.     #name = current_dir + '/data'  
  44.     print('Convert start')     
  45.     img_to_tfrecord(data_path)  
  46.     print('done!')  

运行该段程序可以看到在dataset_tfrecord文件夹下面有test.tfrecord文件生成。
在TF的Session中调用这个生成的文件

[python]  view plain  copy
  1. #encoding=utf-8   
  2. # 设置utf-8编码,方便在程序中加入中文注释.  
  3. import os  
  4. import scipy.misc  
  5. import tensorflow as tf  
  6. import numpy as np  
  7. from test import *  
  8. import matplotlib.pyplot as plt  
  9.   
  10. def read_and_decode(filename_queue):  
  11.           
  12.     reader = tf.TFRecordReader()  
  13.     _, serialized_example = reader.read(filename_queue)  
  14.       
  15.     features = tf.parse_single_example(serialized_example,features = {  
  16.                         'image_raw':tf.FixedLenFeature([], tf.string)})  
  17.     image = tf.decode_raw(features['image_raw'], tf.uint8)  
  18.     image = tf.reshape(image, [OUTPUT_SIZE, OUTPUT_SIZE, 3])  
  19.     image = tf.cast(image, tf.float32)  
  20.     #image = image / 255.0  
  21.       
  22.     return image  
  23.   
  24. data_dir = '/home/sanyuan/dataset_animal/dataset_tfrecords/'   
  25.   
  26. filenames = [os.path.join(data_dir,'train%d.tfrecords' % ii) for ii in range(1)] #如果有多个文件,直接更改这里即可  
  27. filename_queue = tf.train.string_input_producer(filenames)  
  28. image = read_and_decode(filename_queue)  
  29. with tf.Session() as sess:      
  30.     coord = tf.train.Coordinator()  
  31.     threads = tf.train.start_queue_runners(coord=coord)  
  32.     for i in xrange(2):  
  33.         img = sess.run([image])  
  34.         print(img[0].shape)  # 设置batch_size等于1.每次读出来只有一张图  
  35.         plt.imshow(img[0])  
  36.         plt.show()  
  37.     coord.request_stop()  
  38.     coord.join(threads)  
  39.       

程序到这里就已经处理完成了,当然在decorde的过程中也是可以进行一些预处理操作的,不过建议还是在制作数据集的时候进行,TFrecord使用的是队列的方式进行读取数据,这个对于多线程操作来说还是很方便的,只需要设置好格式,每次直接读取就可以了.

猜你喜欢

转载自blog.csdn.net/yinxingtianxia/article/details/78236400