1 CIFAR10数据集
1.1 数据集简介
CIFAR10是一个包含十个类别RGB图像的用于识别普适物体的小型数据集。其内共有50000张训练图片与10000张测试图片,图片尺寸为32*32。
1.2下载数据集中数据
import cifar10
import tensorflow as tf
# tf.app.flags.FLAGS是TensorFlow内部的一个全局变量存储器,同时可以用于命令行参数的处理
FLAGS = tf.app.flags.FLAGS
# 在cifar10模块中预先定义了f.app.flags.FLAGS.data_dir为CIFAR-10的数据路径
FLAGS.data_dir = 'cifar10_data/'
# 如果不存在数据文件,就会执行下载
cifar10.maybe_download_and_extract()
以上两条语句的目的在于将下载的目录地址调整为相对路径下的cifar10_data,而最后一句是检查是否已经下载,若未下载则执行下载命令。当python中显示“下载成功”时,则意味着图片数据下载完毕。
1.3Tensorflow中数据读取机制
在普通计算机读取数据的时候,一般多线程的形式进行数据读取,而对于tf,在内存队列之前又添加了一个“文件名队列”以方便管理。
如上图所示,处理数据程序运行之前,应先把数据放入到文件名队列之中,在程序启动之后,内存将会把文件名队列之中的数据逐条读入,当读到末尾时,将会返回一个异常,以便于标志数据读取完成。
tf.train.string_input_producer
在tf中,我们使用以上这个函数来创建文件名队列。只需要传入一个列表,python将自动给你创建一个队列。此外,这个函数还有两个重要的参数,一个为num_epochs,作用为重复创建次数;另一个为shuffle,作用为是否打乱排序。当num_epochs=2;shuffle=True时,如下所示:
在运行程序之前,我们还需要激活文件名队列,这就是函数train.start_queue_runners的作用。而这个函数的唯一传入参数就是sess,也就是tf.Session()
测试代码为:
# coding:utf-8
import os
if not os.path.exists('read'):
os.makedirs('read/')
# 导入TensorFlow
import tensorflow as tf
# 新建一个Session
with tf.Session() as sess:
# 我们要读三幅图片A.jpg, B.jpg, C.jpg
filename = ['A.jpg', 'B.jpg', 'C.jpg']
# string_input_producer会产生一个文件名队列
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
# reader从文件名队列中读数据。对应的方法是reader.read
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
# tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
tf.local_variables_initializer().run()
# 使用start_queue_runners之后,才会开始填充队列
threads = tf.train.start_queue_runners(sess=sess)
i = 0
while True:
i += 1
# 获取图片数据并保存
image_data = sess.run(value)
with open('read/test_%d.jpg' % i, 'wb') as f:
f.write(image_data)
# 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志
1.4 将数据集保存为图片形式
1.4.1 数据集内容
在CIFAR数据集中,一个样本由3073个字节构成,第一个字节为标签,之后的字节都是图片数据(32*32*3)
1.4.2 tensorflow读取
首先,通过tf.train.string_input_producer来建立队列
之后,通过reader.read读取数据。在这里,由于我们读取的是多个样本,因此不能使用WholeFileReader,而使用tf.FixedLengthRecordReader来进行数据读取
最后,使用train.start_queue_runners来激活队列,再用sess.run()得到图片结果。
if __name__ == '__main__':
# 创建一个会话sess
with tf.Session() as sess:
# 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
reshaped_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
# 这一步start_queue_runner很重要。
# 我们之前有filename_queue = tf.train.string_input_producer(filenames)
# 这个queue必须通过start_queue_runners才能启动
# 缺少start_queue_runners程序将不能执行
threads = tf.train.start_queue_runners(sess=sess)
# 变量初始化
sess.run(tf.global_variables_initializer())
# 创建文件夹cifar10_data/raw/
if not os.path.exists('cifar10_data/raw/'):
os.makedirs('cifar10_data/raw/')
# 保存30张图片
for i in range(30):
# 每次sess.run(reshaped_image),都会取出一张图片
image_array = sess.run(reshaped_image)
scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)
在这里,inputs_origin作为一个函数起着构造队列与使用reader的作用。它的返回值是一个tensor,对应着一张训练图片,在这之后,我们运行第三步,就能得到对应的图片了。
在inputs_origin中调用的read_cifar10,可以看到:
tf.FixedLengthRecordReader(record_bytes=record_bytes)语句中创建了一个reader,每次将会读取record_bytes字节个数据,直到文件结束。也就是读取了正好一个图片大小的数据量。
def inputs_origin(data_dir):
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, 6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
filename_queue = tf.train.string_input_producer(filenames)
read_input = cifar10_input.read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
return reshaped_image
def read_cifar10(filename_queue):
class CIFAR10Record(object):
pass
result = CIFAR10Record()
label_bytes = 1 # 2 for CIFAR-100
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
record_bytes = label_bytes + image_bytes
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
# Convert from a string to a vector of uint8 that is record_bytes long.
record_bytes = tf.decode_raw(value, tf.uint8)
# The first bytes represent the label, which we convert from uint8->int32.
result.label = tf.cast(
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(
tf.strided_slice(record_bytes, [label_bytes],
[label_bytes + image_bytes]),
[result.depth, result.height, result.width])
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result