TFRecord é o formato de armazenamento do conjunto de dados no TensorFlow. Quando organizamos os conjuntos de dados no formato TFRecord, o TensorFlow pode ler e processar com eficiência esses conjuntos de dados, ajudando-nos a conduzir o treinamento de modelos em larga escala com mais eficiência.
Formato: TFRecord pode ser entendido como um arquivo de lista composto por uma série de elementos tf.train.Example serializados, e cada tf.train.Example é composto por vários dicionários tf.train.Feature. O formulário é o seguinte:
[
{
# example 1 (tf.train.Example)
'feature_1': tf.train.Feature,
...
'feature_k': tf.train.Feature
},
...
{
# example N (tf.train.Example)
'feature_1': tf.train.Feature,
...
'feature_k': tf.train.Feature
}
]
# 字典结构如
feature = {
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
- Salvar TFRecord
- Para organizar vários conjuntos de dados no formato TFRecord, precisamos executar as seguintes etapas para cada elemento no conjunto de dados:
ler o elemento de dados na memória - Converta este elemento em um objeto tf.train.Example (cada tf.train.Example consiste em vários dicionários tf.train.Feature, então você precisa criar um dicionário Feature primeiro);
- Serialize o objeto tf.train.Example em uma string e grave-o em um arquivo TFRecord por meio de um tf.io.TFRecordWriter predefinido.
- Para organizar vários conjuntos de dados no formato TFRecord, precisamos executar as seguintes etapas para cada elemento no conjunto de dados:
- Ler dados TFRecord
- Leia o arquivo TFRecord original por meio de tf.data.TFRecordDataset (neste momento, o objeto tf.train.Example no arquivo não foi desserializado) e obtenha um objeto de conjunto de dados tf.data.Dataset;
- Por meio do método Dataset.map, execute a função tf.io.parse_single_example para cada string tf.train.Example serializada no objeto de conjunto de dados para obter a desserialização.
código:
# -*- coding: utf-8 -*-
"""
@Time : 2022-08-03 22:02
@Author : peyzhang
@File : tfRecode.py
@Software: PyCharm
"""
import tensorflow as tf
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
train_cars_dir = 'F:/AI/data//train/car/'
train_human_dir = 'F:/AI/data//train/dog/'
tfrecord_file = 'F:/AI/data//train.tfrecords'
def main():
train_car_filenames = [train_cars_dir + filename for filename in os.listdir(train_cars_dir)]
train_dog_filenames = [train_human_dir + filename for filename in os.listdir(train_human_dir)]
train_filename = train_car_filenames + train_dog_filenames
train_labels = [0] * len(train_car_filenames) + [1] *len(train_dog_filenames)
print(train_filename)
print(train_labels)
with tf.io.TFRecordWriter(tfrecord_file) as writer:
for filename , label in zip(train_filename, train_labels):
image = open(filename,"rb").read()
feature = {
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
def readTFrecord():
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
feature_description = {
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64),
}
def decoding(example_string):
fdict = tf.io.parse_single_example(example_string, feature_description)
fdict['image'] = tf.io.decode_jpeg(fdict['image'])
return fdict['image'], fdict['label']
dataseate = raw_dataset.map(decoding)
for iamge , label in dataseate:
print(iamge, label)
if __name__ == '__main__':
# main()
readTFrecord()