TensorFlow (八): TFRecord

前言

TFRecord 这部分内容困扰了我好几天, 不仅是它的 API 十分晦涩且繁琐, 而且网上的大多数相关教程写的都很抽象, 在看了相关的教材之后才终于有了眉目.

TFRecord 的意义在于: 如果你要训练上万张图片, 他们全部塞进内存里可能需要占用数十甚至上百 GB 的空间, 这时候传统的 feed_dict 方式就不能用了, 需要使用 TFRecord 方式, 即将所有数据转换为 TFRecord 格式的二进制文件, 通过调用 TensorFlow 的相关 API 实现高速的顺序读取, 这样可以在内存有限的情况下完成大体积数据集的训练.

Save

首先是将数据转化为 TFRecord 的代码, 它的 API 真的很晦涩…
它将数据集划分为一个一个的 Example, 比如在图片识别的场景下, 一个 Example 包含一张图片的数据和这张图片对应的标签. 然后通过 tf.python_io.TFRecordWriter 将其写入到指定文件中

代码样例

import tensorflow as tf 
import numpy as np

def int64_feature(value):
	return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

def create_example(x, y):
	return tf.train.Example(features = tf.train.Features(feature = {
		"x": int64_feature(x),
		"y": int64_feature(y)
		}))

def save_record(x, y, index):
	file_name = './data_%d.record' % index
	with tf.python_io.TFRecordWriter(file_name) as writer:
		for i, (_x, _y) in enumerate(zip(x, y)):
			writer.write(create_example(_x, _y).SerializeToString())

def main(_):
	# 这里我生成了 1000 组数据, 每组数据包含一个 x 和一个 y
	x = np.arange(1000).astype(np.int64)
	y = np.arange(1000).astype(np.int64)

	# 将这些数据写入到 10 个文件中
	for i in range(10):
		save_record(x, y, i)

if __name__ == '__main__':
	tf.app.run()

Load

从 TFRecord 文件读取的逻辑是:

  1. 通过 tf.TFRecordReader 读取文件
  2. 每次读取一个 example (二进制格式), 然后通过调用 tf.parse_single_example() 将该 example 解码为原本的数据, 从而获得了数据
  3. 如果你还想要获取一个 batch, 通过 tf.train.shuffle_batch 获取一个局部随机的 batch, 它的逻辑是, 如果你想要一组 16 个数据的 batch, 它会先读比如 2000 个 example, 然后从这 2000 个里面随机给你 16 个数据作为一个 batch, 因此如果你的数据原本是有序的, 那么通过这个方法得到的"随机" batch 其实是一个局部的随机, 我的办法是在将数据转化为 TFRecord 文件之前就打乱其顺序.

代码样例

import tensorflow as tf 
import numpy as np 
import os

cwd = os.getcwd()
paths = []
for i in range(10):
	paths.append(os.path.join(cwd, 'data_%d.record' % i))
filename_queue = tf.train.string_input_producer(paths, num_epochs = 1)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

features = tf.parse_single_example(serialized_example, features = {
	'x': tf.FixedLenFeature([], tf.int64),
	'y': tf.FixedLenFeature([], tf.int64)
	})

batchs = tf.train.shuffle_batch([x, y], batch_size = 16, capacity = 100, min_after_dequeue=50)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	sess.run(tf.local_variables_initializer())
	coord = tf.train.Coordinator()
	threads = tf.train.start_queue_runners(sess = sess, coord = coord)

	try:
		while not coord.should_stop():
			print(sess.run(batchs))
	except tf.errors.OutOfRangeError:
		print('done')
	finally:
		coord.request_stop()
	coord.join(threads)

总结

这部分内容真的把我搞了, 为什么能设计的这么复杂?
我用了两三天的时间才接受了它的设定, 这离我的目标 ( Fine-tuning pre-trained model with large dataset) 更近了一步.

猜你喜欢

转载自blog.csdn.net/vinceee__/article/details/88430708