Slim读取TFrecord文件

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MOU_IT/article/details/82773839

 1、TFrecord文件的格式定义

    TFrecord文件介绍请看这里,使用Slim中的高级API来读取TFrecord文件和普通的读取方式还是有区别的,参考Slim中的实例flower,以及它的TFrecord的格式定义:    

def int64_feature(values):
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def float_feature(values):
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))

   这里要注意的是,TFrecord文件的格式定义中,一定要包含“image/encoded”和“image/format”两个关键字 ,第一个关键字的值为图像的二进制值,第二个为图像的格式。

2、使用Slim读取TFrecord文件的步骤:

   参考以上实例,以及阅读Slim.data的源码,我总结出使用Slim读取TFrecord文件的详细步骤如下:

    (1)设置解码器,一般设置为decoder=slim.tfexample_decoder.TFExampleDecoder(),同时要指定其keys_to_features,和items_to_handlers两个字典参数。key_to_features这个字典需要和TFrecord文件中定义的字典项匹配。items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。

    (2)定义数据集类,一般为dataset=slim.dataset.Dataset():它把datasource、reader、decoder、num_samples等参数封装好。

    (3)定义数据集的数据提供者类,一般为provider=slim.dataset_data_provider.DatasetDataProvider(),需要传入的参数:dataset, num_readers, reader_kwargs, shuffle, num_epochs,common_queue_capacity,common_queue_min, record_key=',seed, scope等。在这个类中:

        1)首先调用_,data=parallel_reader.parallel_read(),这个方法调用tf.train.string_input_producer()得到TFrecord的文件队列(filename_queue),然后根据是否shuffle生成一个公共队列(common queue),用reader_class,common_queue,num_readers,reader_kwargs=reader_kwargs等参数初始化ParallelReader(),然后调用它的read(filename_queuq)方法,这个read()方法先用reader从filename_queue中读取数据然后enqueue到common queue中,然后从common queue中dequeue,从而得到(filename,data)的键值对。

        2)调用items=dataset.decoder.list_items()得到decoder中的items_to_handlers的关键字列表items。

       3)根据1)和2)得到的data和items,调用tensors=dataset.decoder.decode(data, items)。这解码过程中,首先调用example=parsing_ops.parse_single_example(data,keys_to_features)来解析序列化数据得到一个字典特征,然后根据items_to_handlers中传给handler的那些items(这些items来自keys_to_features中的keys),将example中的字典中属于某个handler的多个键值对(因为一个handler用多个items初始化,所以一个handler对应example中多个键值对)交给相应的handler处理,然后每个handler处理完成后返回一个tensor,将所有tensor组成一个列表tensors。

        4)然后将2)中得到的items和3)中得到的tensors进行匹配生成一个字典items_to_tensors。

  (4)调用provider的get方法从items_to_tensors中获取响应的items对应的tensor,比如[image, label] = provider.get(['image', 'label'])

扫描二维码关注公众号,回复: 3310796 查看本文章

3、实例

   这里我的图片放在D:/test/目录下,有0-9共10张图片。

#coding=utf-8
import tensorflow as tf
import numpy as np
import os 
from PIL import Image

slim = tf.contrib.slim 

# 创建TFrecord文件
def create_record_file():
    train_filename = "train.tfrecords"
    if os.path.exists(train_filename):
        os.remove(train_filename)
 
  # 创建.tfrecord文件,准备写入
    writer = tf.python_io.TFRecordWriter('./'+train_filename)
    with tf.Session() as sess:
      for i in range(10):  
          img_raw = tf.gfile.FastGFile("D:/test/"+str(i)+".jpg", 'rb').read()
          decode_data = tf.image.decode_jpeg(img_raw)
          image_shape= decode_data.eval().shape
          example = tf.train.Example(features=tf.train.Features(
                  feature={
                  'image/encoded':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                  'image/format':tf.train.Feature(bytes_list = tf.train.BytesList(value=[b'jpg'])),
                  'image/width':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[1]])), 
                  'image/height':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[0]])),
                  'image/label':tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),               
                  }))
          writer.write(example.SerializeToString())  # 序列化保存
      writer.close()
      print ("保存tfrecord文件成功。")

# 使用Slim的方法从TFrecord文件中读取
def read_record_file():    
    tfrecords_filename = "train.tfrecords"  
    # 将tf.train.Example反序列化成存储之前的格式。由tf完成
    keys_to_features = {
          'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
          'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
          'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
          'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
          'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
      }
    # 将反序列化的数据组装成更高级的格式。由slim完成
    items_to_handlers = {
          'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                                format_key='image/format',
                                                channels=3),
          'label': slim.tfexample_decoder.Tensor('image/label'),
          'height': slim.tfexample_decoder.Tensor('image/height'),
          'width': slim.tfexample_decoder.Tensor('image/width')
      }
    # 定义解码器,进行解码
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    # 定义dataset,该对象定义了数据集的文件位置,解码方式等元信息
    dataset = slim.dataset.Dataset(
          data_sources=tfrecords_filename,
          reader=tf.TFRecordReader,
          decoder=decoder,
          num_samples=10,        # 训练数据的总数
          items_to_descriptions=None,
          num_classes=10,
          )
    #使用provider对象根据dataset信息读取数据
    provider = slim.dataset_data_provider.DatasetDataProvider(
              dataset,
              num_readers=1,
              common_queue_capacity=20,
              common_queue_min=1)
     
     # 获取数据
    [image, label,height,width] = provider.get(['image', 'label','height','width'])    
    with tf.Session() as sess:
      init_op = tf.global_variables_initializer()
      sess.run(init_op)
      coord=tf.train.Coordinator()
      threads= tf.train.start_queue_runners(coord=coord)
      for i in range(10):
        img,l,h,w= sess.run([image,label,height,width])        
        img = tf.reshape(img, [h,w,3]) 
        print (img.shape)       
        img=Image.fromarray(img.eval(), 'RGB')       # 这里将narray转为Image类,Image转narray:a=np.array(img)
        img.save('./'+str(l)+'.jpg')                 # 保存图片

      coord.request_stop()
      coord.join(threads)   

if __name__ == '__main__':
    create_record_file()
    read_record_file()

猜你喜欢

转载自blog.csdn.net/MOU_IT/article/details/82773839