tensorflow学习2:tf.train.slice_input_producer 和 tf.train.batch 函数

转载至:https://blog.csdn.net/dcrmg/article/details/79776876

tensorflow数据读取机制

tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。

具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。

tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,示例图如下:

tf.train.slice_input_producer函数

对于文件名队列,我们使用tf.train.string_input_producer函数。这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。tf.train.string_input_producer还有两个重要的参数,一个是num_epochs,表示epoch数。另外一个就是shuffle是指在一个epoch内文件的顺序是否被打乱。

def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
                         capacity=32, shared_name=None, name=None)
  • 第一个参数 tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。
  • 第二个参数num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置 num_epochs=None,生成器可以无限次遍历tensor列表,如果设置为 num_epochs=N,生成器只能遍历tensor列表N次。
  • 第三个参数shuffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle=True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用 tf.train.batch函数就可以了;如果shuffle=False,就需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。
  • 第四个参数seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle=True的情况下才有用。
  • 第五个参数capacity:设置tensor列表的容量。
  • 第六个参数shared_name:可选参数,如果设置一个‘shared_name’,则在不同的上下文环境(Session)中可以通过这个名字共享生成的tensor。
  • 第七个参数name:可选,设置操作的名称。

tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。

样例代码如下:

import tensorflow as tf
 
images = ['img1', 'img2', 'img3', 'img4', 'img5']
labels= [1,2,3,4,5]
 
epoch_num=8
 
f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=False)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(epoch_num):
        k = sess.run(f)
        print ('************************')
        print (i,k)
 
    coord.request_stop()
    coord.join(threads)

tf.train.batch函数

tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到文件队列,作为训练一个batch的数据,等待tensor出队执行计算。

def batch(tensors, batch_size, num_threads=1, capacity=32,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None)
  • 第一个参数tensors:tensor序列或tensor字典,可以是含有单个样本的序列;
  • 第二个参数batch_size: 生成的batch的大小;
  • 第三个参数num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
  • 第四个参数capacity: 定义生成的tensor序列的最大容量;
  • 第五个参数enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
  • 第六个参数shapes: 可选参数,默认是推测出的传入的tensor的形状;
  • 第七个参数dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
  • 第八个参数allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
  • 第九个参数shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  • 第十个参数name: 操作的名称;

下面给出tensorflow读取数据生成batch的一段简单的代码例子,其中所有的图片放置在train_dir中,为了方便运行我只放置了10张样本,batch_size设置为3,则共有4个batch。tf.train.slice_input_producer中shuttle参数采用False,即按照顺序出队,更方便于我们认识整个过程。num_epochs=None,生成器可以无限次遍历tensor列表。

import tensorflow as tf
import numpy as np  
import os
import matplotlib.pyplot as plt


def show_image(image):
    #显示单张图片
    image = np.transpose(image,(2,0,1)) #需要改变一下维度
    plt.imshow(image[0],cmap='Greys_r')                                  
    plt.show()

def show_batch_image(batch_image):
    #显示一个batch的图片
    for i in range(batch_size):
        image = batch_image[i]
        show_image(image)
           
image_W = 64
image_H = 64
sample_num = 10
epoch_num = 3
batch_size = 3
batch_total = int(sample_num/batch_size)+1 #一个epoch有的batch数目

train_dir = 'C:\\Users\\wyt\\Desktop\\test\\ship'
image_list = []
for file in os.listdir(train_dir):  
    image_list.append(os.path.join(train_dir, file))  


input_queue = tf.train.slice_input_producer([image_list],shuffle = False)  #生成文件名队列,由于shuffle = True,随机产生一个文件名路径
image_contents = tf.read_file(input_queue[0])               #根据上面的文件名路径,读取样本(处于编码下)
image = tf.image.decode_jpeg(image_contents, channels=1)    #进行解码,由于读取的为灰度图,channels=1
#crop_image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) #采用切割的方式改变大小
crop_image = tf.image.resize_images(image, [image_W,image_H],method=0)  #采用放缩的方式改变大小,method=0为双线性插值
standard_image = tf.image.per_image_standardization(crop_image)   # 标准化数据
image_batch = tf.train.batch([crop_image], batch_size=batch_size, num_threads=2, capacity=64, allow_smaller_final_batch=False) #生成batch

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for i in range(epoch_num):  # 每一轮迭代
            print ('epoch is %d'%(i+1))
            for j in range(batch_total): #每一个batch
                #print ('batch num is %d'%(j+1))
                #print ('input_queue:%s'%(sess.run(input_queue)))     
                #print ('image_contents:%s'%(sess.run(image_contents)))    
                #picture, crop_picture,standard_picture = sess.run([image,crop_image,standard_image])
                #show_image(picture)
                #show_image(crop_picture)
                #show_image(standard_picture)
                picture_batch = sess.run(image_batch)
                print(picture_batch.shape)
                show_batch_image(picture_batch )
    except tf.errors.OutOfRangeError:  #如果读取到文件队列末尾会抛出此异常
        print("done! now lets kill all the threads……")
    finally:
        # 协调器coord发出所有线程终止信号
        coord.request_stop()
        print('all threads are asked to stop!')
    coord.join(threads) #把开启的线程加入主线程,等待

我的10张样本图如下

运行上段代码后,4个batch显示的图片依次为: 

可以看出,从上至下,从左至右是按照顺序。经过3个batch后由于只剩一下一张样本没有出队,因而会在按顺序从第一张图补进。

猜你喜欢

转载自blog.csdn.net/Mr_health/article/details/81265912