Leer y escribir archivos tfrecord


Al entrenar el modelo, el preprocesamiento de datos generalmente se convierte al formato tfrecord. La CPU responsable de las operaciones de E / S y la GPU para los cálculos de operaciones numéricas pueden trabajar en paralelo entre sí para asegurar una alta utilización de la GPU. La siguiente es la forma de leer y escribir tfrecord con longitud fija y longitud variable.

1 Escriba tfrecor way

Generalmente, los datos se convierten a la entrada xy la etiqueta en el formato tfrecord de la manera requerida para el entrenamiento del modelo. Existen principalmente dos formas de longitud fija y longitud variable, que se determinan de acuerdo con las aplicaciones y requisitos reales. Si la entrada de cada entrada de ejemplo es de longitud variable, por ejemplo, el número de índices de características de entrada de cada ejemplo no es el mismo, se puede convertir de forma de longitud variable; de ​​lo contrario, se puede convertir de forma fija. manera de la longitud.

1.1 Convertir entidad de longitud variable a tfrecord

import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):
	''' 
	data是一个dict,假设其中key有input_x和input_y,
	对应的value是索引list
	'''
	features = collections.OrderedDict()
	input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))
	features["input_x"] = tf.train.FeatureList(feature=input_x)
	input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))
	features["input_y"] = tf.train.FeatureList(feature=input_y)
	sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))
	writer.write(sequence_example.SerializeToString())

Los siguientes métodos son equivalentes a los métodos anteriores:

def toTF_v2(data)
	sequence_example = tf.train.SequenceExample()
	input_x = sequence_example.feature_lists.feature_list["input_x"]
	input_y = sequence_example.feature_lists.feature_list["input_y"]
	for x in data["input_x"]:
		input_x.feature.add().int64_list.value.append(x)
	for y in data["input_y"]:
		input_y.feature.add().int64_list.value.append(y)
	writer.write(sequence_example.SerializeToString())

1.2 Conversión de entidades de longitud fija a tfrecord

def toTF_fixed(data):
	features = collections.OrderedDict()
	features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))
	features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))
	example = tf.train.Example(features=tf.train.Features(feature=features))
	write.write(example.SerializeToString())

2 Leer tfrecord

Al igual que escribir trrecord, también existen métodos de longitud fija y de longitud variable. Si escribir tfrecord es un método de longitud fija, la lectura de tfrecord también requiere un método de longitud fija. Los métodos de lectura y escritura deben ser coherentes.

2.1 Leer tfrecord en modo de longitud variable

Necesita definir el formato de la característica, si es de longitud variable, defina la característica de tipo tf.FixedLenSequenceFeature

import tensorflow as tf
features = {
    
    
			'input_x': tf.FixedLenSequenceFeature([], tf.int64)
			'input_y': tf.FixedLenSequenceFeature([], tf.int64)
			}

2.2 Leer tfrecord en modo de longitud fija

El método de longitud fija utiliza el tipo tf.FixedLenFeature

seq_length = 10
features = {
    
    
		'input_x': tf.FixedLenFeature([seq_length], tf.int64).
		'input_y': tf.FixedLenFeature([seq_length], tf.int64
		}

3 Leer archivos tfrecord por lotes de hdfs

Cuando la cantidad de datos de entrenamiento es muy grande, generalmente transfiéralos a tfrecord para probar el procesamiento de datos distribuidos para mejorar la eficiencia. Al entrenar el modelo, puede leer archivos por lotes desde un control remoto, como hdfs. La siguiente es una lectura por lotes de archivos tfrecord de hdfs.

def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):
	'''
	其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件
	读的是定长的feature
	'''
	features = {
    
    
			'input_x': tf.FixedLenFeature([seq_length], tf.int64),
			'input_y': tf.FixedLenFeature([seq_length], tf.int64),
	}
	def _decode_record(record):
		# 一个样本解析
		example = tf.io.parse_single_example(record, features)
		multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)
		example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)
		return example

	def _decode_batch_record(batch_record):
		# 一个batch样本解析
		batch_example = tf.io.parse_example(serialized=batch_record, features=features)
		multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)
		batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)
		return batch_example

	def input_fn(params):
		# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
		d = tf.data.Dataset.list_files(file_path)
		d = d.repeat()
		d = d.shuffle(buffer_size=100)
		d = d.appley(
			tf.contrib.data.parallel_interleave(
				tf.data.TFRecordDataset,
				sloppy=True,
				cycle_length=num_cpu_threads))
		d = d.apply(
			tf.contrib.data.map_and_batch(
					lambda record: _decode_record(record),
					batch_size = batch_size,
					num_parallel_batches=num_cpu_threads,
					drop_remainder=True))
		return d

	def input_fn_v2(params):
		d = tf.data.Dataset.list_files(file_path)
		d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\
		batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(
			tf.data.experimental.AUTOTUNE).repeat()
		return d
	return input_fn
	#return input_fn_v2

Se proporcionan dos funciones analíticas arriba. Ambos métodos input_fn y input_fn_v2 son factibles. Con el entrenamiento del estimador, se puede realizar el procesamiento paralelo entre los datos de lectura de la CPU y los datos de entrenamiento de la GPU, lo que reduce el tiempo de espera, mejora la utilización de la GPU y acelera la velocidad de entrenamiento. Al analizar un archivo tfrecord, hay cuatro formas de elegir de acuerdo con su formato de datos específico:

  • Analizar una sola muestra, características de longitud fija: tf.io.parse_single_example ()
  • Analizar una sola muestra, características de longitud variable: tf.io.parse_single_sequence_example ()
  • Analizar muestras por lotes, características de longitud fija: tf.io.parse_example ()
  • Analizar muestras por lotes, características de longitud fija: tf.io.parse_sequence_example ()

Supongo que te gusta

Origin blog.csdn.net/BGoodHabit/article/details/108559976
Recomendado
Clasificación