Read and write tfrecord files


When training a model, the preprocessed data is usually converted to the tfrecord format. The CPU, responsible for I/O, and the GPU, responsible for numerical computation, can then work in parallel, which keeps GPU utilization high. The following shows how to read and write tfrecord in both fixed-length and variable-length form.

1 Writing tfrecord

Generally, the data is converted into the input x and label y required for model training, in tfrecord format. There are two main approaches, fixed length and variable length, chosen according to the actual application and requirements. If the input of each example is of variable length, for example the number of input feature indexes differs from example to example, convert in the variable-length manner; otherwise, convert in the fixed-length manner.

1.1 Converting variable-length features to tfrecord

import collections
import tensorflow as tf

writer = tf.python_io.TFRecordWriter('data.tfrecord')

def toTF(data):
	'''
	data is a dict; assume its keys include input_x and input_y,
	and the corresponding values are lists of indexes
	'''
	features = collections.OrderedDict()
	# one Feature per value, so the list parses with tf.FixedLenSequenceFeature([]) later
	input_x = [tf.train.Feature(int64_list=tf.train.Int64List(value=[x])) for x in data["input_x"]]
	features["input_x"] = tf.train.FeatureList(feature=input_x)
	input_y = [tf.train.Feature(int64_list=tf.train.Int64List(value=[y])) for y in data["input_y"]]
	features["input_y"] = tf.train.FeatureList(feature=input_y)
	sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))
	writer.write(sequence_example.SerializeToString())

The following method is equivalent to the one above:

def toTF_v2(data):
	sequence_example = tf.train.SequenceExample()
	input_x = sequence_example.feature_lists.feature_list["input_x"]
	input_y = sequence_example.feature_lists.feature_list["input_y"]
	for x in data["input_x"]:
		input_x.feature.add().int64_list.value.append(x)
	for y in data["input_y"]:
		input_y.feature.add().int64_list.value.append(y)
	writer.write(sequence_example.SerializeToString())

1.2 Converting fixed-length features to tfrecord

def toTF_fixed(data):
	features = collections.OrderedDict()
	features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))
	features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))
	example = tf.train.Example(features=tf.train.Features(feature=features))
	write.write(example.SerializeToString())
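
A minimal usage sketch for the writers above; the sample dicts are hypothetical stand-ins for your own preprocessed data:

samples = [
	{"input_x": [3, 1, 4, 1], "input_y": [0, 1]},
	{"input_x": [5, 9], "input_y": [1]},
]
for sample in samples:
	toTF(sample)  # use toTF_fixed(sample) when every sample has the same length
writer.close()  # flush the buffered records to data.tfrecord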

2 Reading tfrecord

As with writing tfrecord, there are fixed-length and variable-length methods. If tfrecord was written in the fixed-length manner, it must also be read in the fixed-length manner; the read and write methods must be consistent.

2.1 Reading tfrecord in variable-length mode

You need to define the format of the features; for variable length, define features of the tf.FixedLenSequenceFeature type:

import tensorflow as tf

features = {
	'input_x': tf.FixedLenSequenceFeature([], tf.int64),
	'input_y': tf.FixedLenSequenceFeature([], tf.int64),
}
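
The dict above describes the feature_lists of a SequenceExample, so it is passed as sequence_features when parsing. A minimal sketch, assuming the data.tfrecord file written earlier:

def _parse_var_len(record):
	# parse_single_sequence_example returns (context, sequence); only the sequence part is needed here
	_, sequence = tf.io.parse_single_sequence_example(record, sequence_features=features)
	return sequence

dataset = tf.data.TFRecordDataset("data.tfrecord").map(_parse_var_len)

Because the lengths vary from example to example, batching this dataset would use padded_batch rather than batch.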

2.2 Reading tfrecord in fixed-length mode

The fixed-length method uses the tf.FixedLenFeature type:

seq_length = 10
features = {
	'input_x': tf.FixedLenFeature([seq_length], tf.int64),
	'input_y': tf.FixedLenFeature([seq_length], tf.int64),
}
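
With fixed-length features the record is an ordinary Example, so tf.io.parse_single_example suffices. A minimal sketch, again assuming data.tfrecord:

def _parse_fixed_len(record):
	# every parsed tensor has the static shape [seq_length]
	return tf.io.parse_single_example(record, features)

dataset = tf.data.TFRecordDataset("data.tfrecord").map(_parse_fixed_len).batch(32)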

3 Reading batches of tfrecord files from HDFS

When the amount of training data is very large, it is generally converted to tfrecord with distributed data processing to improve efficiency. When training the model, the files can then be read in batches from a remote store such as HDFS. The following reads batches of tfrecord files from HDFS.

def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):
	'''
	file_path is the path of the tfrecord files on HDFS,
	e.g. all tfrecord files under the data directory;
	fixed-length features are read here
	'''
	features = {
		'input_x': tf.FixedLenFeature([seq_length], tf.int64),
		'input_y': tf.FixedLenFeature([seq_length], tf.int64),
	}
	def _decode_record(record):
		# parse a single example
		example = tf.io.parse_single_example(record, features)
		multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)
		example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)
		return example

	def _decode_batch_record(batch_record):
		# parse a batch of examples
		batch_example = tf.io.parse_example(serialized=batch_record, features=features)
		multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)
		batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)
		return batch_example

	def input_fn(params):
		# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
		d = tf.data.Dataset.list_files(file_path)
		d = d.repeat()
		d = d.shuffle(buffer_size=100)
		d = d.apply(
			tf.contrib.data.parallel_interleave(
				tf.data.TFRecordDataset,
				sloppy=True,
				cycle_length=num_cpu_threads))
		d = d.apply(
			tf.contrib.data.map_and_batch(
					lambda record: _decode_record(record),
					batch_size = batch_size,
					num_parallel_batches=num_cpu_threads,
					drop_remainder=True))
		return d

	def input_fn_v2(params):
		d = tf.data.Dataset.list_files(file_path)
		d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\
		batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(
			tf.data.experimental.AUTOTUNE).repeat()
		return d
	return input_fn
	#return input_fn_v2
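
A hedged usage sketch with an estimator; the HDFS glob, thread count, and model_fn here are placeholders, not values from the original post:

input_fn = input_fn_builder(
	file_path="hdfs://namenode:8020/data/*.tfrecord",  # hypothetical HDFS path
	num_cpu_threads=4,
	seq_length=10,
	num_class=5,
	batch_size=32)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="./ckpt")  # model_fn defined elsewhere
estimator.train(input_fn=input_fn, max_steps=10000)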

Two parsing functions are provided above; both input_fn and input_fn_v2 are feasible. With estimator training, the CPU reading data and the GPU training on that data run in parallel, which reduces waiting time, improves GPU utilization, and speeds up training. When parsing a tfrecord file, there are four options to choose from, according to your specific data format:

  • Parse a single sample, fixed-length features: tf.io.parse_single_example()
  • Parse a single sample, variable length features: tf.io.parse_single_sequence_example()
  • Parse batch samples, fixed-length features: tf.io.parse_example()
  • Parse batch samples, variable-length features: tf.io.parse_sequence_example()
