将CIFAR-10 数据集保存为图片形式

介绍了TensorFlow 的数据读取的基本原理,再来看如何i卖取CIFAR-10数据。在CIFAR-10 数据集中,文件data batch I .bin 、data batch 2.bin 、data_batch 5 . bin 和test_ batch.bin 中各有10000 个样本。一个样本由3073 个字节组成,第一个字节为标签( label ),剩下3072 个字节为图像数据(官方说明文档)。

样本和样本之间没高多余的字节分割, 因此这几个二进制文件的大小都是30730000 字节。如何用TensorFlow读取CIFAR-10 数据呢?,步骤和上一篇文章(TensorFlow的数据读取机制)一样。

  • 第一步,用tf. train .string_ input producer 建立队列。
  • 第二步,通过reader.read 读数据。在上一篇文章中,一个文件就是一张图片,因此用的reader 是tf. WholeFileReader() 。CIFAR-10 数据是以固定字节存在文件中的,一个文件中含再多个样本3 因此不能使用tf. WholeFileReader (),而是用tf.FixedLengthRecordReader() 。
  • 第三步,调用tf. train . start_ queue_ runners 。
  • 最后,通过sess.run()取出图片结果。

遵循上面的步骤,本文会做一个实验:将CIFAR-10 数据集中的图片读取出来,并保存为.jpg 恪式。对应的程序为cifar 10 extract. py 。看步骤中的tf. train.string_ input_produ cer ,tf.FixedLengthRecordReader ()、tf.train.start_queue_ runners 、sess.run ()都在什么地方。按照程序的执行顺序来看:

#coding: utf-8
# 导入当前目录的cifar10_input,这个模块负责读入cifar10数据
import cifar10_input
# 导入TensorFlow和其他一些可能用到的模块。
import tensorflow as tf
import os
import scipy.misc

def inputs_origin(data_dir):
  # filenames一共5个,从data_batch_1.bin到data_batch_5.bin
  # 读入的都是训练图像
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  # 判断文件是否存在
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)
  # 将文件名的list包装成TensorFlow中queue的形式
  filename_queue = tf.train.string_input_producer(filenames)
  # cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
  # 返回的结果read_input的属性uint8image就是图像的Tensor
  read_input = cifar10_input.read_cifar10(filename_queue)
  # 将图片转换为实数形式
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)
  # 返回的reshaped_image是一张图片的tensor
  # 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
  return reshaped_image

if __name__ == '__main__':
  # 创建一个会话sess
  with tf.Session() as sess:
    # 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
    reshaped_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
    # 这一步start_queue_runner很重要。
    # 我们之前有filename_queue = tf.train.string_input_producer(filenames)
    # 这个queue必须通过start_queue_runners才能启动
    # 缺少start_queue_runners程序将不能执行
    threads = tf.train.start_queue_runners(sess=sess)
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    # 创建文件夹cifar10_data/raw/
    if not os.path.exists('cifar10_data/raw/'):
      os.makedirs('cifar10_data/raw/')
    # 保存30张图片
    for i in range(30):
      # 每次sess.run(reshaped_image),都会取出一张图片
      image_array = sess.run(reshaped_image)
      # 将图片保存
      scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)

inputs_ origin 是一个函数。这个函数中包含了前两个步骤, tf.train.string_inpu t_pr:oducer 和使用reader 。函数的返回值reshaped_image 是一个Tensor,对应一张训练图像。下面要做的并不是直接运行sess.run(reshaped_image),而是使用threads = tf. train. start_ queue_ runners( sess=sess)。只高调用过tf. train.start_ queue_ runners 后,才会让系统中的所高队列真正地“运行”,开始从文件中读数据。如果不调用这条i吾旬,系统将会一直等待。

最后用sess.run(reshaped_image)取出训练图片并保存。此程序一共在文件夹cifar10data/raw/中保存了30 张图片。读者可以打开该文件夹,看到原始的CIFAR-10 训练图片。再回过头来看inputs_ origin 函数:

def inputs_origin(data_dir):
  # filenames一共5个,从data_batch_1.bin到data_batch_5.bin
  # 读入的都是训练图像
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  # 判断文件是否存在
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)
  # 将文件名的list包装成TensorFlow中queue的形式
  filename_queue = tf.train.string_input_producer(filenames)
  # cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
  # 返回的结果read_input的属性uint8image就是图像的Tensor
  read_input = cifar10_input.read_cifar10(filename_queue)
  # 将图片转换为实数形式
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)
  # 返回的reshaped_image是一张图片的tensor
  # 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
  return reshaped_image

tf.train.string_input_producer(filenames )创建了一个文件名队列,真中filenames 是一个列表,包含从data_batch_1.bin 到data_batch_5.bin 一共5 个文件名。这正好对应了CIFAR-10 的训练集。cifar10_ input.read_ cifar_10(filename_queue)对应“使用reader ”的步骤。为此需要查看cifar10_input.py中的read cifar10函数,其中关键的代码如下。

# Dimensions of the images in the CIFAR-10 dataset.
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
# input format.
label_bytes = 1  # 2 for CIFAR-100
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
# Every record consists of a label followed by the image, with a
# fixed number of bytes for each.
record_bytes = label_bytes + image_bytes

# Read a record, getting filenames from the filename_queue.  No
# header or footer in the CIFAR-10 format, so we leave header_bytes
# and footer_bytes at their default of 0.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)

i吾句tf.FixedLengthRecordReader(record_ bytes=record _bytes )创建了一个reader , 包每次在文件中读取record_bytes 字节的数据,直到文件结束。结合代码, record_bytes 就等于1+32*32*3,即3073 ,正好对应CIFAR-10中一个样本的字节长度。使用reader.read(filename_queue)后, reader 从之前建立好的文件名队列中渎职数据(以Tensor 的形式)。简单处理结果后由函数返回。至此,读者应当对CIFAR-10 数据的读取流程及TensorFlow 的读取机制相当熟悉了。

猜你喜欢

转载自blog.csdn.net/czp_374/article/details/81123578
今日推荐