Tensorflow读取二进制文件并将数据写进tfrecords后读取tfrecords

    tensorflow可以读取样本长度固定的二进制文件,比如CIFAR-10数据,该二进制数据中一个样本由1字节的label和32*32*3字节的image组成。TFRecords是tensorflow设计的一种内置的文件格式,是一种二进制文件,它能更好地利用内存,更方便复制和移动。

该程序实现tensorflow首先读取CIFAR-10的二进制数据,然后将其保存成tfrecords格式的文件,最后实现对tfrecords文件的读取。


一 Tensorflow读取二进制文件

1、构造文件队列

file_queue = tf.train.string_input_producer(file_list) # file_list:文件列表

2、构建二进制文件阅读器,读取内容(读取一个样本字节大小)

reader = tf.FixedLengthRecordReader(bytes_length) # bytes_length:一个样本字节大小
key, value = reader.read(file_queue)

3、解码内容,二进制文件中读取为uint8格式

label_image = tf.decode_raw(value, tf.uint8)

4、分割出特征值和标签值

label = tf.cast(tf.slice(label_image, [0], [1]), tf.int32)
image = tf.slice(label_image, [1], [3073])

5、对图片的特征值数据进行形状改变[3072]-->[32, 32,3],方便后面批处理

image_reshape = tf.reshape(image, [32,32,3])

6、批处理数据

image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)

二 将数据写入TFRecords文件

1、建立TFRecords存储器

writer = tf.python_io.TFRecordWriter(save_path) # save_path:保存路径+文件名

2、循环将所有样本写入文件,每个样本都要构造example

for i in range(example_num): # example_num:样本个数
    image = image_batch[i].eval().tostring()
    label = label_batch[i].eval()[0]
    
    #构造一个样本的example, 由三种存储格式:Int64List、BytesList、FloatList
    example = tf.train.Example(features=tf.train.Features(features={

        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))

    # 写入数据, 必须序列化后才能写入
    writer.writer(example.SerializeToString)

# 关闭
writer.close()

三 TFRecords读取

1、构造文件队列

file_queue = tf.train.string_input_producer(file_list)

2、构造文件阅读器,读取一个example

reader = tf.TFRecordReader()
key, value = reader.read(file_queue)

3、解析example

features = tf.parse_sigle_example(value, features={
    
    "iamge": tf.FixedLenFeature([], tf.string), # 第一个参数为shape,一般不指定,第二个参数为数据类型,必须跟存储的一致。
    "label": tf.FixedLenFeature([], tf.int64),
})

4、解码内容,如果读取的是string类型则需要解码,如果是int64或者float32则不需要解码

image = tf.decode_raw(features["iamge"], tf.uint8)

5、固定形状,方便批处理

iamge_reshape = tf.reshape(iamge, [32,32,3])
label = tf.cast(features["label"], tf.int32)

6、批处理

iamge_batch, label_batch = tf.train.batch([iamge_reshape, label], batch_size=10, num_threads=1, capacity=10)

四 完整代码

# -*- coding: utf-8 -*-

"""
--------------------------------------------------------
# @Version : python3.7
# @Author  : wangTongGen
# @File    : CIFAR-10二进制文件读取并写入TFRecords.py
# @Software: PyCharm
# @Time    : 2019/3/26 09:34
--------------------------------------------------------
# @Description: 
--------------------------------------------------------
"""

import os
import tensorflow as tf


os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

class CifarRead():
    """
        完成二进制文件的读取,
    """
    def __init__(self, file_list):
        # 文件列表
        self.file_list = file_list
        # 定义图片的一些属性
        self.height = 32
        self.weight = 32
        self.channel = 3
        # 二进制读取每张图片的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.weight * self.channel
        # 一个样本的总字节大小
        self.bytes = self.label_bytes + self.image_bytes

    # 读取二进制文件
    def decodeCifar(self):
        # 1.构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)
        # 2.构造二进制文件读取器,读取内容(读取一个样本的字节数)
        reader = tf.FixedLengthRecordReader(self.bytes) # 指定一次读多少个样本
        key, value = reader.read(file_queue)
        # 3.解码内容,二进制文件内容的解码
        label_image = tf.decode_raw(value, tf.uint8)
        # print(label_image)
        # 4.分割出图片数据和标签数据
        label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])

        # 5.对图片的特征数据进行形状改变[3072] --> [32,32,3]
        image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
        # print(label, image_reshape)

        # 6.批处理数据
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=2, capacity=10)
        # print(image_batch, label_batch)

        return image_batch, label_batch

    # 将图片的特征值和目标值保存进tfrecords
    def write_to_tfrecords(self, image_batch, label_batch):

        # 1.建立TFRecord存储器
        save_path = "../datas/cifar.tfrecords"
        writer = tf.python_io.TFRecordWriter(save_path)

        # 2.循环将所有样本写入文件,每个样本都要构造example协议
        for i in range(10):
            #取出第i个样本的特征值和目标值
            image = image_batch[i].eval().tostring()
            label = label_batch[i].eval()[0]

            # 构造一个样本的example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入单独的样本
            writer.write(example.SerializeToString())


        # 关闭
        writer.close()


    def read_from_tfrecords(self):

        # 1.构造文件队列
        file_queue = tf.train.string_input_producer(["../datas/cifar.tfrecords"])

        # 2.构造文件阅读器,读取内容example,value为一个样本序列化example
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        print(value)

        # 3.解析example
        features = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })

        # 4.解码内容,如果读取的是string则需要解码,如果是int64,float32不需要解码
        image = tf.decode_raw(features["image"], tf.uint8)

        # 5.固定图片的形状,方便批处理
        image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
        label = tf.cast(features["label"], tf.int32)

        # print(image_reshape, label)

        # 6.进行批处理
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)


        return image_batch, label_batch


if __name__ == '__main__':

    file_dir = "/Users/wtg/pycharm-workspace/cifar-10-batches-bin"
    file_name_list = os.listdir(file_dir)
    file_list = [os.path.join(file_dir, file) for file in file_name_list if file[-3:]=="bin"]

    cf = CifarRead(file_list)

    image_batch_from_bin, label_batch_from_bin = cf.decodeCifar()



    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()
        # 开启读取文件的线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        print(sess.run([image_batch_from_bin, label_batch_from_bin]))

        #存进tfrecords文件
        print("开始存储为TFRecords文件")
        cf.write_to_tfrecords(image_batch_from_bin, label_batch_from_bin)
        print("存储结束")

        # print("开始从TFRecords文件读取内容")
        # image_batch_from_TF, label_batch_from_TF = cf.read_from_tfrecords()
        # print("读取结束")
        # print(sess.run([image_batch_from_TF, label_batch_from_TF]))

        coord.request_stop()
        coord.join(threads)

猜你喜欢

转载自blog.csdn.net/qq_41689620/article/details/88829734