将CIFAR-10数据集保存为图片形式

# -*- coding: utf-8 -*-
"""
Created on Tue Aug  7 20:45:01 2018

@author: lenovo
"""
import cifar10_input
import tensorflow as tf
import os
import scipy.misc

def inputs_origin(data_dir):
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in range(1, 6)]
    # 判断文件是否存在
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # 将文件名的list包装成TensorFlow中queue的形式
    filename_queue = tf.train.string_input_producer(filenames)
    read_input = cifar10_input.read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    return reshaped_image

if __name__ == '__main__':
  # 创建一个会话sess
  with tf.Session() as sess:
    # 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
    reshaped_image = inputs_origin('C:\\Users\\lenovo\\Desktop\\cifar10_data\\cifar-10-batches-bin')
    # 这一步start_queue_runner很重要。
    # 我们之前有filename_queue = tf.train.string_input_producer(filenames)
    # 这个queue必须通过start_queue_runners才能启动
    # 缺少start_queue_runners程序将不能执行
    threads = tf.train.start_queue_runners(sess=sess)
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    # 创建文件夹cifar10_data/raw/
    if not os.path.exists('cifar10_data/raw/'):
      os.makedirs('cifar10_data/raw/')
    # 保存30张图片
    for i in range(30):
      # 每次sess.run(reshaped_image),都会取出一张图片
      image_array = sess.run(reshaped_image)
      # 将图片保存
      scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)

猜你喜欢

转载自blog.csdn.net/qq_41858768/article/details/81489697
今日推荐