TFRecords

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u011956147/article/details/79290163

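TFRecords is TensorFlow's built-in standard data format, similar to LMDB in Caffe. Once the data has been converted, it is convenient and fast to use, but it takes up extra disk space. In many cases, such as detection, reading and processing the raw data directly is also a good option; I will add a follow-up on reading data directly with multiple threads later. This post records how to use TFRecords, mainly as a personal memo.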


1. Generating TFRecords

Use tf.train.Example to define the format of the data written into the protobuf buffer, and write it with tf.python_io.TFRecordWriter. For more about tf.train.Example, see: https://www.tensorflow.org/api_docs/python/tf/train/Example

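import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.train.*)

def _int64_feature(value):
    """Wrap an int (or a list of ints) as an Int64List feature."""
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _float_feature(value):
    """Wrap a float (or a list of floats) as a FloatList feature."""
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _bytes_feature(value):
    """Wrap a bytes string as a BytesList feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _convert_to_example_simple(image_example, image_buffer):
    """
    Convert one sample to a tf.train.Example proto.
    :param image_example: dict with 'label' (int) and 'bbox' (dict holding
        xmin/ymin/xmax/ymax and the 10 landmark coordinates)
    :param image_buffer: bytes, raw uint8 pixels of an HxWx3 image
        (e.g. img.tobytes()); the reader decodes it with tf.decode_raw,
        so it must NOT be JPEG-encoded
    :return: Example proto
    """
    class_label = image_example['label']
    bbox = image_example['bbox']
    roi = [bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']]
    landmark = [bbox['xlefteye'], bbox['ylefteye'], bbox['xrighteye'], bbox['yrighteye'], bbox['xnose'], bbox['ynose'],
                bbox['xleftmouth'], bbox['yleftmouth'], bbox['xrightmouth'], bbox['yrightmouth']]

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_buffer),
        'image/label': _int64_feature(class_label),
        'image/roi': _float_feature(roi),
        'image/landmark': _float_feature(landmark)
    }))
    return example

# tf_filename, image_example and image_data are prepared beforehand
with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
    example = _convert_to_example_simple(image_example, image_data)
    tfrecord_writer.write(example.SerializeToString())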

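Here, image_example (the dict holding the label and the box/landmark coordinates) and image_data (the raw image bytes) must both be prepared beforehand.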
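To make "prepared beforehand" concrete: _convert_to_example_simple reads image_example['label'] and an image_example['bbox'] dict with the four box keys and the ten landmark keys, and the reader in section 2 decodes image/encoded with tf.decode_raw, so image_data has to be raw uint8 pixel bytes. Below is a minimal sketch of building that pair with NumPy; make_inputs is a hypothetical helper, and the box and landmark values are whatever your annotation provides.

```python
import numpy as np

# Landmark key order used by _convert_to_example_simple
LANDMARK_KEYS = ['xlefteye', 'ylefteye', 'xrighteye', 'yrighteye', 'xnose', 'ynose',
                 'xleftmouth', 'yleftmouth', 'xrightmouth', 'yrightmouth']


def make_inputs(img, label, roi, landmark):
    """Build (image_example, image_data) for one sample (hypothetical helper).

    img: HxWx3 array; roi: [xmin, ymin, xmax, ymax]; landmark: 10 numbers.
    """
    bbox = dict(zip(['xmin', 'ymin', 'xmax', 'ymax'], map(float, roi)))
    bbox.update(zip(LANDMARK_KEYS, map(float, landmark)))
    image_example = {'label': int(label), 'bbox': bbox}
    # Raw uint8 bytes (not JPEG), because the reader uses tf.decode_raw
    image_data = np.ascontiguousarray(img, dtype=np.uint8).tobytes()
    return image_example, image_data
```

The image must already have the fixed size that read_tfrecord later reshapes to, otherwise the reshape fails at read time.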


2. Parsing the data

import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline

# image_size: side length of the stored square images; assumed to be defined
# at module level and to match the size used when the TFRecords were written.

def read_tfrecord(tfrecord_file, batch_size):
    # filename queue; with the default num_epochs=None it cycles forever
    filename_queue = tf.train.string_input_producer([tfrecord_file], shuffle=True)
    # read one serialized record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # keys, shapes and dtypes must match what was written
    image_features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/label': tf.FixedLenFeature([], tf.int64),
            'image/roi': tf.FixedLenFeature([4], tf.float32),
            'image/landmark': tf.FixedLenFeature([10], tf.float32)
        }
    )
    # raw bytes -> flat uint8 vector -> image_size x image_size x 3
    image = tf.decode_raw(image_features['image/encoded'], tf.uint8)
    image = tf.reshape(image, [image_size, image_size, 3])
    # preprocessing: scale pixels from [0, 255] to roughly [-1, 1]
    image = (tf.cast(image, tf.float32) - 127.5) / 128

    label = tf.cast(image_features['image/label'], tf.float32)
    roi = tf.cast(image_features['image/roi'], tf.float32)
    landmark = tf.cast(image_features['image/landmark'], tf.float32)
    # group single examples into batches, filled by 2 background threads
    image, label, roi, landmark = tf.train.batch(
        [image, label, roi, landmark],
        batch_size=batch_size,
        num_threads=2,
        capacity=1 * batch_size
    )
    label = tf.reshape(label, [batch_size])
    roi = tf.reshape(roi, [batch_size, 4])
    landmark = tf.reshape(landmark, [batch_size, 10])
    return image, label, roi, landmark
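The decode, reshape and normalization steps in read_tfrecord are easy to check outside the graph. A NumPy sketch that mirrors those three ops (decode_and_normalize is a hypothetical name; it only mirrors the TF ops and is not part of the pipeline):

```python
import numpy as np


def decode_and_normalize(buf: bytes, image_size: int) -> np.ndarray:
    """NumPy mirror of tf.decode_raw -> tf.reshape -> (x - 127.5) / 128."""
    image = np.frombuffer(buf, dtype=np.uint8)
    # Raises ValueError unless len(buf) == image_size * image_size * 3,
    # just as tf.reshape fails in the graph (e.g. for a JPEG buffer).
    image = image.reshape(image_size, image_size, 3)
    return (image.astype(np.float32) - 127.5) / 128
```

(x - 127.5) / 128 maps 0 to -0.99609375 and 255 to 0.99609375, so the pixels end up just inside [-1, 1]. The reshape also shows why image/encoded has to hold raw bytes: an encoded JPEG has a different length and the reshape fails.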

3. Using it in training

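import tensorflow as tf

# ... other code (defines dataset_dir, BATCH_SIZE, MAX_STEP, the model, etc.;
# read_tfrecord is the function from section 2)
image_batch, label_batch, bbox_batch, landmark_batch = read_tfrecord(dataset_dir, BATCH_SIZE)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# start the background threads that fill the filename queue and batch queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for step in range(1, MAX_STEP + 1):
        if coord.should_stop():
            break
        image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
            [image_batch, label_batch, bbox_batch, landmark_batch])
        # ... other code (e.g. the training step using these arrays)
except tf.errors.OutOfRangeError:
    # only raised when the input queue runs dry, i.e. if num_epochs was set
    print("Success!")
finally:
    coord.request_stop()
coord.join(threads)
sess.close()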
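The Coordinator and start_queue_runners code follows a standard producer-consumer pattern: background threads keep filling a bounded queue, the main loop blocks until a full batch is available, and request_stop() plus join() shut the threads down. As an analogy only (this is not how TensorFlow implements it), here is the same control flow with plain Python threading; run_queue_pipeline and make_example are hypothetical names, threading.Event stands in for the Coordinator and queue.Queue(maxsize=capacity) for the batch queue.

```python
import itertools
import queue
import threading


def run_queue_pipeline(make_example, batch_size, max_step,
                       num_threads=2, capacity=8):
    """Plain-Python analogy of start_queue_runners + Coordinator + batching."""
    q = queue.Queue(maxsize=capacity)  # bounded buffer, like the batch queue
    stop = threading.Event()           # stands in for tf.train.Coordinator
    counter = itertools.count()        # index of the next example to produce
    lock = threading.Lock()            # protects the shared counter

    def enqueue_loop():
        # One "queue runner" thread: produce until asked to stop.
        while not stop.is_set():       # like coord.should_stop()
            with lock:
                i = next(counter)
            item = make_example(i)
            while not stop.is_set():
                try:
                    q.put(item, timeout=0.05)
                    break
                except queue.Full:
                    continue           # queue full: retry unless stopping

    threads = [threading.Thread(target=enqueue_loop, daemon=True)
               for _ in range(num_threads)]
    for t in threads:
        t.start()
    batches = []
    try:
        for _ in range(max_step):
            # Blocks until batch_size items are available, like sess.run on a batch.
            batches.append([q.get() for _ in range(batch_size)])
    finally:
        stop.set()                     # like coord.request_stop()
        for t in threads:
            t.join()                   # like coord.join(threads)
    return batches
```

With more than one producer thread the order of examples is not deterministic, which is also the case for tf.train.batch with num_threads=2.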

4. Notes

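  1. In TensorFlow (1.x), everything in the graph is an operation or a tensor; you only get actual values by evaluating them with sess.run().
  2. TFRecordReader keeps dequeuing file names from the filename queue and reading their records until the queue is empty. With the default num_epochs=None in string_input_producer the queue never empties; with num_epochs set, reading raises tf.errors.OutOfRangeError once all epochs are consumed (and tf.local_variables_initializer() must also be run, because the epoch counter is a local variable).
  3. Before use, the graph's variables must be initialized (and tf.train.start_queue_runners must be called, otherwise sess.run on the batch tensors blocks forever):
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)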
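The behaviour in point 2 depends on the num_epochs argument of tf.train.string_input_producer: with the default None the file names are cycled forever, so the except tf.errors.OutOfRangeError branch in section 3 is never reached; with num_epochs set, each file name is handed out that many times (reshuffled per epoch when shuffle=True) and then the queue is closed. A plain-Python generator modelling that order (filename_queue is a hypothetical name, not a TF API):

```python
import random


def filename_queue(filenames, num_epochs=None, shuffle=True, seed=None):
    """Model of the order in which string_input_producer hands file names
    to a reader (an analogy, not TensorFlow's implementation)."""
    names = list(filenames)
    if not names:
        raise ValueError("need at least one file name")
    rng = random.Random(seed)
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        order = list(names)
        if shuffle:
            rng.shuffle(order)  # reshuffled independently for every epoch
        yield from order
        epoch += 1
    # Returning here corresponds to the queue being closed and empty:
    # in TF the next reader.read() would raise tf.errors.OutOfRangeError.
```

Note that with a single file, as in section 2, shuffle=True has no visible effect on the order.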
