TensorFlow之TFRecords文件的存储与读取讲解及代码实现

先聊一下tfrecord, 这是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,TFRecords是TensorFlow中的设计的一种内置的文件格式,优点有如下几种:

  • 统一不同输入文件的框架
  • 它是更好的利用内存,更方便复制和移动(TFRecord压缩的二进制文件, protocal buffer序列化)
  • 是用于将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

一、TFRecords存储

在将其他数据存储为TFRecords文件的时候,需要经过两个步骤:

  • 建立TFRecord存储器
  • 构造每个样本的Example模块

1、建立TFRecord存储器

tf.python_io.TFRecordWriter(path)

  • 写入tfrecords文件
  • path : TFRecords文件的路径
  • return : 写文件
  • 方法: 
    • write(record):向文件中写入一个字符串记录(即一个样本)
    • close() : 关闭文件写入器

注:此处的字符串为一个序列化的Example,通过Example.SerializeToString()来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。

2、构造每个样本的Example协议块

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

上面这段代码即为Example协议块的规则,详解如下: 
(1)tf.train.Example(features = None)

  • 写入tfrecords文件
  • features : tf.train.Features类型的特征实例
  • return : example协议格式块

(2)tf.train.Features(feature = None)

  • 构造每个样本的信息键值对
  • feature : 字典数据,key为要保存的名字,value为tf.train.Feature实例
  • return : Features类型

(3)tf.train.Feature(**options) 
options可以选择如下三种格式数据:

  • bytes_list = tf.train.BytesList(value = [Bytes])
  • int64_list = tf.train.Int64List(value = [Value])
  • float_list = tf.trian.FloatList(value = [Value])

(4)将图片数据转化为TFRecords的例子: 
对每一个样本,都做如下的处理:

example = tf.train.Example(feature = tf.train.Features(feature = {
                            "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)]))
                             "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)]))
    }))
  • 二、TFRecords读取方法

1.流程:

和文件阅读器的流程基本相同,只是中间多了一步解析过程

2.解析TFRecords的example协议内存块:

(1)tf.parse_single_example(serialized,features=None,name= None

  • 解析一个单一的Example原型
  • serialized : 标量字符串的Tensor,一个序列化的Example,文件经过文件阅读器之后的value
  • features :字典数据,key为读取的名字,value为FixedLenFeature
  • return : 一个键值对组成的字典,键为读取的名字

(2)tf.FixedLenFeature(shape,dtype)

  • shape : 输入数据的形状,一般不指定,为空列表
  • dtype : 输入数据类型,与存储进文件的类型要一致,类型只能是float32,int 64, string
  • return : Tensor (即使有零的部分也存储)

(3)上面(1)中features中的value还可以为tf.VarLenFeature(),但是这种方式用的比较少,它返回的是SparseTensor数据,这是一种只存储非零部分的数据格式,了解即可。

三、代码实现

1.将CSV文件转化为TFRecords文件

import tensorflow as tf
import numpy as np
import pandas as pd

train_frame = pd.read_csv("train.csv")
print(train_frame.head())
train_labels_frame = train_frame.pop(item="label")
train_values = train_frame.values
train_labels = train_labels_frame.values
print("values shape: ", train_values.shape)
print("labels shape:", train_labels.shape)

writer = tf.python_io.TFRecordWriter("csv_train.tfrecords")

for i in range(train_values.shape[0]):
    image_raw = train_values[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]]))
            }
        )
    )
    writer.write(record=example.SerializeToString())

writer.close()

2.将图片文件转化为TFRecords文件

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd

def get_label_from_filename(filename):
    return 1

filenames = tf.train.match_filenames_once('.\data\*.png')

writer = tf.python_io.TFRecordWriter('png_train.tfrecords')

for filename in filenames:
    img=mpimg.imread(filename)
    print("{} shape is {}".format(filename, img.shape))
    img_raw = img.tostring()
    label = get_label_from_filename(filename)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }
        )
    )
    writer.write(record=example.SerializeToString())

writer.close()

3.将二进制文件转化为TFRecords文件

"""
读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
"""

#命令行参数
FLAGS = tf.app.flags.FLAGS       #获取值
tf.app.flags.DEFINE_string("tfrecord_dir","./tmp/cifar10.tfrecords","写入图片数据文件的文件名")


#读取二进制转换文件
class CifarRead(object):
    """
    读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
    """
    def __init__(self,file_list):
        """
        初始化图片参数
        :param file_list:图片的路径名称列表
        """

        #文件列表
        self.file_list = file_list

        #图片大小,二进制文件字节数
        self.height = 32
        self.width = 32
        self.channel = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes


    def read_and_decode(self):
        """
        解析二进制文件到张量
        :return: 批处理的image,label张量
        """
        #1.构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        #2.阅读器读取内容
        reader = tf.FixedLengthRecordReader(self.bytes)

        key ,value = reader.read(file_queue)    #key为文件名,value为元组

        print(value)

        #3.进行解码,处理格式
        label_image = tf.decode_raw(value,tf.uint8)
        print(label_image)

        #处理格式,image,label
        #进行切片处理,标签值
        #tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式
        label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)

        #处理图片数据
        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
        print(image)

        #处理图片的形状,提供给批处理
        #因为image的形状已经固定,此处形状用动态形状来改变
        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
        print(image_tensor)

        #批处理图片数据
        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)

        return image_batch,label_batch

    def write_to_tfrecords(self,image_batch,label_batch):
        """
        将文件写入到TFRecords文件中
        :param image_batch:
        :param label_batch:
        :return:
        """

        #建立TFRecords文件存储器
        writer = tf.python_io.TFRecordWriter(FLAGS.tfrecord_dir)      #传进去命令行参数

        #循环取出每个样本的值,构造example协议块
        for i in range(10):

            #取出图片的值,  #写进去的是值,而不是tensor类型,
            # 写入example需要bytes文件格式,将tensor转化为bytes用tostring()来转化
            image = image_batch[i].eval().tostring()

            #取出标签值,写入example中需要使用int形式,所以需要强制转换int
            label = int(label_batch[i].eval()[0])

            #构造每个样本的example协议块
            example = tf.train.Example(features = tf.train.Features(feature = {
                "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
                "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
            }))

            #写进去序列化后的值
            writer.write(example.SerializeToString())     #此处其实是将其压缩成一个二进制数据

        writer.close()

        return None



    def read_from_tfrecords(self):
        """
        从TFRecords文件当中读取图片数据(解析example)
        :param self:
        :return: image_batch,label_batch
        """

        #1.构造文件队列
        file_queue = tf.train.string_input_producer([FLAGS.tfrecord_dir])    #参数为文件名列表

        #2.构造阅读器
        reader = tf.TFRecordReader()

        key,value = reader.read(file_queue)

        #3.解析协议块,返回的值是字典
        feature = tf.parse_single_example(value,features={
            "image":tf.FixedLenFeature([],tf.string),
            "label":tf.FixedLenFeature([],tf.int64)
        })

        #feature["image"],feature["label"]
        #处理标签数据    ,cast()只能在int和float之间进行转换
        label = tf.cast(feature["label"],tf.int32)    #将数据类型int64 转换为int32

        #处理图片数据,由于是一个string,要进行解码,  #将字节转换为数字向量表示,字节为一字符串类型的张量
        #如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型
        # decode_raw()可以将数据从string,bytes转换为int,float类型的
        image = tf.decode_raw(feature["image"],tf.uint8)

        #转换图片的形状,此处需要用动态形状进行转换
        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])

        #4.批处理
        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)

        return image_batch,label_batch


if __name__ == '__main__':

    # 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
    # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')

    #加上路径
    file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]

    #初始化参数
    cr = CifarRead(file_list)

    #读取二进制文件
    # image_batch,label_batch = cr.read_and_decode()

    #从已经存储的TFRecords文件中解析出原始数据
    image_batch, label_batch = cr.read_from_tfrecords()

    with tf.Session() as sess:
        #线程协调器
        coord = tf.train.Coordinator()

        #开启线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        print(sess.run([image_batch,label_batch]))

        # print("存进TFRecords文件")
        # cr.write_to_tfrecords(image_batch,label_batch)
        # print("存进文件完毕")

        #回收线程
        coord.request_stop()
        coord.join(threads)
  • 注: 

上段代码分为两个部分:

  • 第一部分是被注释掉的几行代码,表示的是将二进制文件转化为张量,并经过Example协议存储到TFRecords文件当中;
  • 第二部分是从已经存储好数据信息的TFRecords文件中,经过解析,转化为最初的二进制文件。

参考地址:https://blog.csdn.net/chengshuhao1991/article/details/78656724


猜你喜欢

转载自blog.csdn.net/m0_37407756/article/details/80671905