数据集类型转换—TFRecords文件

TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。

格式:TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件,而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成。形式如下:

[
    {
    
       # example 1 (tf.train.Example)
        'feature_1': tf.train.Feature,
        ...
        'feature_k': tf.train.Feature
    },
    ...
    {
    
       # example N (tf.train.Example)
        'feature_1': tf.train.Feature,
        ...
        'feature_k': tf.train.Feature
    }
]
# 字典结构如
feature = {
    
    
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), 
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))  
            }
  • 保存TFRecord
    • 为了将形式各样的数据集整理为 TFRecord 格式,我们需要对数据集中的每个元素进行以下步骤:
      读取该数据元素到内存
    • 将该元素转换为 tf.train.Example 对象(每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成,因此需要先建立 Feature 的字典);
    • 将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。
  • 读取 TFRecord 数据
    • 通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件(此时文件中的 tf.train.Example 对象尚未被反序列化),获得一个 tf.data.Dataset 数据集对象;
    • 通过 Dataset.map 方法,对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数,从而实现反序列化。

代码:

# -*- coding: utf-8 -*-
"""
@Time    : 2022-08-03 22:02
@Author  : peyzhang
@File    : tfRecode.py
@Software: PyCharm
"""

import tensorflow as tf
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

train_cars_dir = 'F:/AI/data//train/car/'
train_human_dir = 'F:/AI/data//train/dog/'
tfrecord_file = 'F:/AI/data//train.tfrecords'


def main():

    train_car_filenames =  [train_cars_dir + filename for filename in os.listdir(train_cars_dir)]
    train_dog_filenames =  [train_human_dir + filename for filename in os.listdir(train_human_dir)]
    train_filename = train_car_filenames + train_dog_filenames
    train_labels = [0] * len(train_car_filenames) + [1] *len(train_dog_filenames)
    print(train_filename)
    print(train_labels)

    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for filename , label in zip(train_filename, train_labels):
            image = open(filename,"rb").read()
            feature = {
    
    
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }

            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
def readTFrecord():
    raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
    feature_description = {
    
    
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def decoding(example_string):
        fdict = tf.io.parse_single_example(example_string, feature_description)
        fdict['image'] = tf.io.decode_jpeg(fdict['image'])
        return fdict['image'], fdict['label']

    dataseate = raw_dataset.map(decoding)
    for iamge , label in dataseate:
        print(iamge, label)

if __name__ == '__main__':
    # main()
    readTFrecord()



猜你喜欢

转载自blog.csdn.net/Peyzhang/article/details/126150626
今日推荐