TensorFlow中读取数据的方法及其优缺点(一)

使用queue读硬盘中的数据,详细过程可以参考:https://zhuanlan.zhihu.com/p/27238630

【(占坑)关于文件读取部分,列出胡所有文件的文件名】

以上图片为例。用queue方法读取为batch。

主要函数是tf.train.string_input_producer,现已更新为“tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)”。主要用法是:

tf.string_input_producer(string_tensor, num_epochs=None, 
                      shuffle=True, seed=None, capacity=32,
                      shared_name=None, name=None, cancel_op=None)
  • string_tensor:一维的tensor,由需要读取的文件的文件名组成
  • num_epochs:一个epoch代表将数据集的中数据全部使用一遍
  • shuffle:是否打乱数据集,True表示随机打乱数据集
  • seed:随机数种子,用于shuffle

完整的代码如下:

tf.reset_default_graph()

# TODO 读取大量数据时,用os中的方法读取
filename = ['01.jpg','02.jpg','03.jpg','04.jpg','05.jpg','06.jpg']

# !!!如果不设置 num_epochs 的数量,则文件队列是无限循环的,没有结束标志,
# 即不会抛出OutOfRangeError的错误,程序会一直执行下去! 如果没有抛出OutOfRangeError,
# 最坑的是即使,你在try部分里发出所有线程终止信号,程序依然无法终止,只有抛出了OutOfRangeError
# 的错误,所有线程才会终止,否则会报错RuntimeError: Coordinator stopped with threads 
# still running:
filename_queue = tf.train.string_input_producer(filename,shuffle=True,
                                                seed=10,num_epochs=2)

# reader从文件名队列中读数据。对应的方法是reader.read
reader = tf.WholeFileReader()
key,value = reader.read(filename_queue)

with tf.Session() as sess:
    #tf.train.string_input_producer定义了一个epoch变量,它是local的,要对它进行初始化
    tf.local_variables_initializer().run()
    
    # 开启一个协调器
    coord = tf.train.Coordinator()
    # 使用start_queue_runners启动队列填充
    threads = tf.train.start_queue_runners(sess=sess)
    
    i = 0
    try:
        while not coord.should_stop():
            while i <= 20:
                i += 1
                image = sess.run(value)
                with open('data/produce_%d.jpg' % i,'wb') as f:
                    f.write(image)
            
    except tf.errors.OutOfRangeError: #读取完列队中的数据会抛出这个错误
        print('All data have been Readed')
    finally:
        # 协调器coord发出所有线程终止信号
        coord.request_stop()
        print('All threads stoped')
        
    # 把开启的线程加入主线程,等待threads结束,(不懂啥意思)
    coord.join(threads)

注意:# !!!如果不设置 num_epochs 的数量,则文件队列是无限循环的,没有结束标志,即不会抛出OutOfRangeError的错误,程序会一直执行下去! 如果没有抛出OutOfRangeError,最坑的是即使,你在try部分里发出所有线程终止信号,程序依然无法终止,只有抛出了OutOfRangeError的错误,所有线程才会终止,否则会报错RuntimeError: Coordinator stopped with threads still running: 水平有限暂时还不知道怎么解决。

最终生成图片的结果如下:

该方法的问题:

1、每次只能抛出一个样本,而不是一个batch,生成batch需要用到,tf.train.batch()

2、读取出来的value的值是一个字符串,用reader.read()进行编码的,并不能直接使用。

3、如果读取的图片是带有标签的,则标签部分在shuffle之后怎么解决(考虑tf.train.slice_input_producer)

tf.train.slice_input_producer()

tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

    slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,  
                             capacity=32, shared_name=None, name=None)  
  • tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。
  • num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置 num_epochs=None,生成器可以无限次遍历tensor列表,如果设置为 num_epochs=N,生成器只能遍历tensor列表N次。
  • shuffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle=True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用 tf.train.batch函数就可以了;如果shuffle=False,就需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。
  • seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle=True的情况下才有用。

tf.train.batch()

tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到文件队列,作为训练一个batch的数据,等待tensor出队执行计算。

    batch(tensors, batch_size, num_threads=1, capacity=32,  
              enqueue_many=False, shapes=None, dynamic_pad=False,  
              allow_smaller_final_batch=False, shared_name=None, name=None)  
  • tensors:tensor序列或tensor字典,可以是含有单个样本的序列;
  • batch_size: 生成的batch的大小;
  • num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
  • capacity: 定义生成的tensor序列的最大容量;
  • enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
  • shapes: 可选参数,默认是推测出的传入的tensor的形状;
  • dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
  • allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
  • shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  • name: 操作的名称;

第三个问题,可以添加每个image对应的label标签“labels = [0,0,0,1,1,1]”,然后根据key中保存的对应于image的图片的名字,在labels列表里找到label,即“label = labels[filename.index(key_.decode())]”,解决

第二个问题,用“tf.image.decode_jpeg”解码,最好写在session外面;

第一个问题,tf.train.batch()的使用,不好意思用不会,一直报错“tf.python.framework.errors_impl.InvalidArgumentError”捕捉错误以后忽略还是不行,但是由于“tf.WholeFileReader().read()”本来就是一次读一个,所以“tf.train.batch()”这个东西应该还是读出数据以后组合起来的,于是可以考虑自己动手设置batch。然后我就想到了个傻X方法……

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 26 15:24:34 2019

@author: Fj
"""

import tensorflow as tf
import matplotlib.pyplot as plt

tf.reset_default_graph()

# TODO 读取大量数据时,用os中的方法读取
filename = ['01.jpg','02.jpg','03.jpg','04.jpg','05.jpg','06.jpg']
labels = [0,0,0,1,1,1] # 对应的label = [0,0,0,1,1,1] 0为猫,1为狗

batch_size = 32 # batch_size < len(filename)*num_epochs
train_step = 4 # train_step*batch_size > len(filename)*num_epochs 这样才能抛出OutOfRangeError

# !!!如果不设置 num_epochs 的数量,则文件队列是无限循环的,没有结束标志,
# 即不会抛出OutOfRangeError的错误,程序会一直执行下去! 如果没有抛出OutOfRangeError,
# 最坑的是即使,你在try部分里发出所有线程终止信号,程序依然无法终止,只有抛出了OutOfRangeError
# 的错误,所有线程才会终止,否则会报错RuntimeError: Coordinator stopped with threads 
# still running:
filename_queue = tf.train.string_input_producer(filename,shuffle=True,
                                                seed=10,num_epochs=20)

# reader从文件名队列中读数据。对应的方法是reader.read
reader = tf.WholeFileReader()
key,value = reader.read(filename_queue)

image = tf.image.decode_jpeg(value) # 原图为什么格式就decode为什么格式
# image.set_shape([224,224,3]) # 从这个地方统一图片的大小,不写的话就是原图尺寸

with tf.Session() as sess:
    #tf.train.string_input_producer定义了一个epoch变量,它是local的,要对它进行初始化
    tf.local_variables_initializer().run()
    
    # 开启一个协调器
    coord = tf.train.Coordinator()
    # 使用start_queue_runners启动队列填充
    threads = tf.train.start_queue_runners(sess=sess)

    
    try:
        for i in range(train_step):
            
            # 生成一个batch
            batch_img = []
            batch_label = []
            j=0
            while j<batch_size:
                idx,img = sess.run([key,image])
                label = labels[filename.index(idx.decode())]
                batch_img.append(img)
                batch_label.append(label)
                j+=1
            
            print(batch_label) 
        """
        ===============
        """
                
    except tf.errors.OutOfRangeError: #读取完列队中的数据会抛出这个错误
        print('All data have been Readed')
    finally:
        # 协调器coord发出所有线程终止信号
        coord.request_stop()
        print('All threads stoped')
        
    # 把开启的线程加入主线程,等待threads结束,(不懂啥意思)
    coord.join(threads)

    plt.imshow(img)

然后再print(batch_label) 那个地方训练就可以了。

希望有大佬来指点一下…… 然后再去看tf.data的用法

猜你喜欢

转载自blog.csdn.net/Huang_Fj/article/details/101445890