tensorflow可以读取样本长度固定的二进制文件,比如CIFAR-10数据,该二进制数据中一个样本由1字节的label和32*32*3字节的image组成。TFRecords是tensorflow设计的一种内置的文件格式,是一种二进制文件,它能更好地利用内存,更方便复制和移动。
该程序实现tensorflow首先读取CIFAR-10的二进制数据,然后将其保存成tfrecords格式的文件,最后实现对tfrecords文件的读取。
一 Tensorflow读取二进制文件
1、构造文件队列
file_queue = tf.train.string_input_producer(file_list) # file_list:文件列表
2、构建二进制文件阅读器,读取内容(读取一个样本字节大小)
reader = tf.FixedLengthRecordReader(bytes_length) # bytes_length:一个样本字节大小
key, value = reader.read(file_queue)
3、解码内容,二进制文件中读取为uint8格式
label_image = tf.decode_raw(value, tf.uint8)
4、分割出特征值和标签值
label = tf.cast(tf.slice(label_image, [0], [1]), tf.int32)
image = tf.slice(label_image, [1], [3073])
5、对图片的特征值数据进行形状改变[3072]-->[32, 32,3],方便后面批处理
image_reshape = tf.reshape(image, [32,32,3])
6、批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
二 将数据写入TFRecords文件
1、建立TFRecords存储器
writer = tf.python_io.TFRecordWriter(save_path) # save_path:保存路径+文件名
2、循环将所有样本写入文件,每个样本都要构造example
for i in range(example_num): # example_num:样本个数
image = image_batch[i].eval().tostring()
label = label_batch[i].eval()[0]
#构造一个样本的example, 由三种存储格式:Int64List、BytesList、FloatList
example = tf.train.Example(features=tf.train.Features(features={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
# 写入数据, 必须序列化后才能写入
writer.writer(example.SerializeToString)
# 关闭
writer.close()
三 TFRecords读取
1、构造文件队列
file_queue = tf.train.string_input_producer(file_list)
2、构造文件阅读器,读取一个example
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
3、解析example
features = tf.parse_sigle_example(value, features={
"iamge": tf.FixedLenFeature([], tf.string), # 第一个参数为shape,一般不指定,第二个参数为数据类型,必须跟存储的一致。
"label": tf.FixedLenFeature([], tf.int64),
})
4、解码内容,如果读取的是string类型则需要解码,如果是int64或者float32则不需要解码
image = tf.decode_raw(features["iamge"], tf.uint8)
5、固定形状,方便批处理
iamge_reshape = tf.reshape(iamge, [32,32,3])
label = tf.cast(features["label"], tf.int32)
6、批处理
iamge_batch, label_batch = tf.train.batch([iamge_reshape, label], batch_size=10, num_threads=1, capacity=10)
四 完整代码
# -*- coding: utf-8 -*-
"""
--------------------------------------------------------
# @Version : python3.7
# @Author : wangTongGen
# @File : CIFAR-10二进制文件读取并写入TFRecords.py
# @Software: PyCharm
# @Time : 2019/3/26 09:34
--------------------------------------------------------
# @Description:
--------------------------------------------------------
"""
import os
import tensorflow as tf
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
class CifarRead():
"""
完成二进制文件的读取,
"""
def __init__(self, file_list):
# 文件列表
self.file_list = file_list
# 定义图片的一些属性
self.height = 32
self.weight = 32
self.channel = 3
# 二进制读取每张图片的字节
self.label_bytes = 1
self.image_bytes = self.height * self.weight * self.channel
# 一个样本的总字节大小
self.bytes = self.label_bytes + self.image_bytes
# 读取二进制文件
def decodeCifar(self):
# 1.构造文件队列
file_queue = tf.train.string_input_producer(self.file_list)
# 2.构造二进制文件读取器,读取内容(读取一个样本的字节数)
reader = tf.FixedLengthRecordReader(self.bytes) # 指定一次读多少个样本
key, value = reader.read(file_queue)
# 3.解码内容,二进制文件内容的解码
label_image = tf.decode_raw(value, tf.uint8)
# print(label_image)
# 4.分割出图片数据和标签数据
label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
# 5.对图片的特征数据进行形状改变[3072] --> [32,32,3]
image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
# print(label, image_reshape)
# 6.批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=2, capacity=10)
# print(image_batch, label_batch)
return image_batch, label_batch
# 将图片的特征值和目标值保存进tfrecords
def write_to_tfrecords(self, image_batch, label_batch):
# 1.建立TFRecord存储器
save_path = "../datas/cifar.tfrecords"
writer = tf.python_io.TFRecordWriter(save_path)
# 2.循环将所有样本写入文件,每个样本都要构造example协议
for i in range(10):
#取出第i个样本的特征值和目标值
image = image_batch[i].eval().tostring()
label = label_batch[i].eval()[0]
# 构造一个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
# 写入单独的样本
writer.write(example.SerializeToString())
# 关闭
writer.close()
def read_from_tfrecords(self):
# 1.构造文件队列
file_queue = tf.train.string_input_producer(["../datas/cifar.tfrecords"])
# 2.构造文件阅读器,读取内容example,value为一个样本序列化example
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
print(value)
# 3.解析example
features = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64),
})
# 4.解码内容,如果读取的是string则需要解码,如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8)
# 5.固定图片的形状,方便批处理
image_reshape = tf.reshape(image, [self.height, self.weight, self.channel])
label = tf.cast(features["label"], tf.int32)
# print(image_reshape, label)
# 6.进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
return image_batch, label_batch
if __name__ == '__main__':
file_dir = "/Users/wtg/pycharm-workspace/cifar-10-batches-bin"
file_name_list = os.listdir(file_dir)
file_list = [os.path.join(file_dir, file) for file in file_name_list if file[-3:]=="bin"]
cf = CifarRead(file_list)
image_batch_from_bin, label_batch_from_bin = cf.decodeCifar()
with tf.Session() as sess:
# 定义一个线程协调器
coord = tf.train.Coordinator()
# 开启读取文件的线程
threads = tf.train.start_queue_runners(sess, coord=coord)
print(sess.run([image_batch_from_bin, label_batch_from_bin]))
#存进tfrecords文件
print("开始存储为TFRecords文件")
cf.write_to_tfrecords(image_batch_from_bin, label_batch_from_bin)
print("存储结束")
# print("开始从TFRecords文件读取内容")
# image_batch_from_TF, label_batch_from_TF = cf.read_from_tfrecords()
# print("读取结束")
# print(sess.run([image_batch_from_TF, label_batch_from_TF]))
coord.request_stop()
coord.join(threads)