Deep learning和tensorflow学习记录(三十二):TFRecord格式保存与读取

一、生成TFRecord数据

首先定义两个函数将features打包成example proto。

import tensorflow as tf

def _int64_feature(value):
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _float_feature(value):
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

创建writer用来写入tfrecord数据。

writer = tf.python_io.TFRecordWriter('test.tfrecord')

创建example,写入到writer中。

for i in range(0, 2):
    a = 5 + i
    b = 2 * (i + 1)
    print("i:", i, a)
    print("i:", i, b)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={'a': _int64_feature(a),
                     'b': _float_feature(b)}))
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()

输出:

i: 0 5
i: 0 2
i: 1 6
i: 1 4

这里循环两次生成两个example写入到writer中。

在当前路径下生成了一个test.tfrecord文件。

二、读取TFRecord文件

下面来看一下如何读取tfrecord文件。

filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

features = tf.parse_single_example(serialized_example,
                                   features={
                                       'a': tf.FixedLenFeature([], dtype=tf.int64),
                                       'b': tf.FixedLenFeature([], dtype=tf.float32)
                                   })

现在可以从features里面去数据了。

a_out = features['a']
b_out = features['b']

print(a_out)
print(b_out)

输出:

Tensor("ParseSingleExample/ParseSingleExample:0", shape=(), dtype=int64)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=float32)

在Session run的时候以batch的方式读取数据。batch_size=1,每次只取一个数据。

a_batch, b_batch = tf.train.shuffle_batch([a_out, b_out], batch_size=1, capacity=100, min_after_dequeue=50, num_threads=1)
sess.run(tf.global_variables_initializer())

tf.train.start_queue_runners(sess=sess)
a_val, b_val = sess.run([a_batch, b_batch])
print('first run:')
print(a_val)
print(b_val)
a_val, b_val = sess.run([a_batch, b_batch])
print('second run:')
print(a_val)
print(b_val)

输出:

first run:
[6]
[4.]
second run:
[5]
[2.]

可以看到,虽然我们存储的时候是先存的5,2,后存的6, 4,但是因为我们使用了tf.train.shuffle_batch随机乱序的读取数据,所以第一次运行的时候取到了6,4。

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/81097033