TensorFlow之tfrecords文件

 tfrecords的分析与存储实例

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。总而言之,这样的文件格式好处多多,所以让我们用起来吧。

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。
 

import os
import tensorflow as tf


# 定义自定义命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('cifar_dir','./data/cifar10/cifar-10-batches-bin','文件的目录')
tf.app.flags.DEFINE_string('cifar_tfrecords','./tmp/cifar.tfrecords','存储tfrecords的文件')


class CifarRead(object):
    '''
    完成读取二进制文件,写进tfrecords,读取tfrecords
    :param object:
    :return:
    '''
    def __init__(self,filelist):
        # 文件列表
        self.file_list = filelist

        # 定义读取的图片的一些属性
        self.height = 32
        self.width = 32
        self.channel = 3
        # 二进制文件每张图片的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self):
        # 1. 构建文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        # 2. 构建二进制文件读取器,读取内容,每个样本的字节数
        reader = tf.FixedLengthRecordReader(self.bytes)

        key,value = reader.read(file_queue)

        # 3. 解码内容,二进制文件内容的解码 label_image包含目标值和特征值
        label_image = tf.decode_raw(value,tf.uint8)
        print(label_image)

        # 4.分割出图片和标签数据,特征值和目标值
        label = tf.slice(label_image,[0],[self.label_bytes])

        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
        print('---->')
        print(image)

        # 5. 可以对图片的特征数据进行形状的改变 [3072]-->[32,32,3]
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])

        print('======>')
        print(label)
        print('======>')

        # 6. 批处理数据
        image_batch,label_batch = tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)

        print(image_batch,label_batch)
        return image_batch,label_batch

    def write_ro_tfrecords(self,image_batch,label_batch):
        '''
        将图片的特征值和目标值存进tfrecords
        :param image_batch: 10张图片的特征值
        :param label_batch: 10张图片的目标值
        :return: None
        '''
        # 1.建立TFRecord存储器
        writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)

        # 2. 循环将所有样本写入文件,每张图片样本都要构造example协议
        for i in range(10):
            # 取出第i个图片数据的特征值和目标值
            image = image_batch[i].eval().tostring()

            label = int(label_batch[i].eval()[0])


            # 构造一个样本的example
            example = tf.train.Example(features=tf.train.Features(feature={
                'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))


            # 写入单独的样本
            writer.write(example.SerializeToString())

        # 关闭
        writer.close()
        return None

if __name__ == '__main__':
    # 找到文件,构建列表  路径+名字  ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)

    # 拼接路径 重新组成列表
    filelist = [os.path.join(FLAGS.cifar_dir,file) for file in file_name if file[-3:] == 'bin']

    # 调用函数传参
    cf = CifarRead(filelist)
    image_batch,label_batch = cf.read_and_decode()

    # 开启会话
    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()

        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        # 存进tfrecords文件
        print('开始存储')
        cf.write_ro_tfrecords(image_batch,label_batch)
        print('结束存储')
        # 打印读取的内容
        # print(sess.run([image_batch,label_batch]))

        # 回收子线程
        coord.request_stop()

        coord.join(threads)

 tfrecords的读取实例

import os
import tensorflow as tf

# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('cifar_dir', './data/cifar10/cifar-10-batches-bin', '文件的目录')
tf.app.flags.DEFINE_string('cifar_tfrecords', './tmp/cifar.tfrecords', '存储tfrecords的文件')


class CifarRead(object):
    '''
    完成读取二进制文件,写进tfrecords,读取tfrecords
    :param object:
    :return:
    '''

    def __init__(self, filelist):
        # 文件列表
        self.file_list = filelist

        # 定义读取的图片的一些属性
        self.height = 32
        self.width = 32
        self.channel = 3
        # 二进制文件每张图片的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self):
        # 1. 构建文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        # 2. 构建二进制文件读取器,读取内容,每个样本的字节数
        reader = tf.FixedLengthRecordReader(self.bytes)

        key, value = reader.read(file_queue)

        # 3. 解码内容,二进制文件内容的解码 label_image包含目标值和特征值
        label_image = tf.decode_raw(value, tf.uint8)
        print(label_image)

        # 4.分割出图片和标签数据,特征值和目标值
        label = tf.slice(label_image, [0], [self.label_bytes])

        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        print('---->')
        print(image)

        # 5. 可以对图片的特征数据进行形状的改变 [3072]-->[32,32,3]
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])

        print('======>')
        print(label)
        print('======>')

        # 6. 批处理数据
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)

        print(image_batch, label_batch)

        return image_batch, label_batch
    # 读取并存储tfrecords文件
    # def write_ro_tfrecords(self, image_batch, label_batch):
    #     '''
    #     将图片的特征值和目标值存进tfrecords
    #     :param image_batch: 10张图片的特征值
    #     :param label_batch: 10张图片的目标值
    #     :return: None
    #     '''
    #     # 1.建立TFRecord存储器
    #     writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
    #
    #     # 2. 循环将所有样本写入文件,每张图片样本都要构造example协议
    #     for i in range(10):
    #         # 取出第i个图片数据的特征值和目标值
    #         image = image_batch[i].eval().tostring()
    #
    #         label = int(label_batch[i].eval()[0])
    #
    #         # 构造一个样本的example
    #         example = tf.train.Example(features=tf.train.Features(feature={
    #             'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
    #             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    #         }))
    #
    #         # 写入单独的样本
    #         writer.write(example.SerializeToString())
    #
    #     # 关闭
    #     writer.close()
    #     return None

    def read_from_tfrecords(self):
        # 1. 构造文件队列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])

        # 2. 构造文件阅读器,读取内容example,value一个样本的序列化example
        reader = tf.TFRecordReader()

        key, value = reader.read(file_queue)

        # 3. 解析example
        features = tf.parse_single_example(value, features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

        print(features['image'], features['label'])

        # 4. 解码内容,如果读取的内容格式是string需要解码,如果是int64,float32不需要解码
        image = tf.decode_raw(features['image'], tf.uint8)

        # 固定图片的形状,方便与批处理
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])

        label = tf.cast(features['label'], tf.int32)

        print(image_reshape, label)

        # 进行批处理
        image_batch,label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)

        return image_batch,label_batch


if __name__ == '__main__':
    # 找到文件,构建列表  路径+名字  ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)

    # 拼接路径 重新组成列表
    filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == 'bin']

    # 调用函数传参
    cf = CifarRead(filelist)
    # image_batch,label_batch = cf.read_and_decode()

    image_batch, label_batch = cf.read_from_tfrecords()

    # 开启会话
    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()

        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        # 存进tfrecords文件
        # print('开始存储')
        # cf.write_ro_tfrecords(image_batch,label_batch)
        # print('结束存储')
        # 打印读取的内容
        print(sess.run([image_batch,label_batch]))

        # 回收子线程
        coord.request_stop()

        coord.join(threads)

猜你喜欢

转载自blog.csdn.net/qq_40716944/article/details/84324497