Python循环产生批量数据batch

版权声明:本文为博主原创文章,未经博主允许不得转载(pan_jinquan) https://blog.csdn.net/guyuealian/article/details/83473298

Python循环产生批量数据batch

目录

Python循环产生批量数据batch

一、Python循环产生批量数据batch

二、TensorFlow循环产生批量数据batch 

(1) tf.train.slice_input_producer

(2) tf.train.batch和tf.train.shuffle_batch

(3) TF循环产生批量数据batch 的完整例子


一、Python循环产生批量数据batch

   在机器学习中,经常需要产生一个batch的数据用于训练模型,比如tensorflow的接口tf.train.batch就可以实现数据批量读取的操作。本博客将实现一个类似于tensorflow接口tf.train.batch的方法,循环产生批量数据batch。实现的代码和测试的代码如下:

   TXT文本如下,格式:图片名 label1 label2 ,注意label可以多个

1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18

    要想产生batch数据,关键是要用到Python的关键字yield,实现一个batch一个batch的返回数据,代码实现主要有两个方法:

def get_data_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生批量数据batch
    :param inputs: list数据
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
def get_next_batch(batch):
    return batch.__next__()

    使用时,将数据传到 get_data_batch( )方法,然后使用get_next_batch( )获得一个batch数据,完整的Python代码如下:

# -*-coding: utf-8 -*-
"""
    @Project: create_batch_data
    @File   : create_batch_data.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2017-10-27 18:20:15
"""
import math
import random
import os
import glob
import numpy as np

def get_data_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生批量数据batch
    :param inputs: list数据
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
    # rows,cols=inputs.shape
    rows=len(inputs[0])
    indices =list(range(rows))
    if shuffle:
        random.shuffle(indices )
    while True:
        batch_indices = indices[0:batch_size] # 产生一个batch的index
        indices= indices [batch_size:] + indices[:batch_size]  # 循环移位,以便产生下一个batch
        batch_data=[]
        for data in inputs:
            temp_data=find_list(batch_indices,data)
            batch_data.append(temp_data)
        yield batch_data

def find_list(indices,data):
    out=[]
    for i in indices:
        out=out+[data[i]]
    return out
def get_list_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生batch数据
    :param inputs: list数据
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
    if shuffle:
        random.shuffle(inputs)
    while True:
        batch_inouts = inputs[0:batch_size]
        inputs=inputs[batch_size:] + inputs[:batch_size]# 循环移位,以便产生下一个batch
        yield batch_inouts

def load_file_list(text_dir):
    text_dir = os.path.join(text_dir, '*.txt')
    text_list = glob.glob(text_dir)
    return text_list

def get_next_batch(batch):
    return batch.__next__()

def load_image_labels(finename):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
    :param test_files:
    :return:
    '''
    images_list=[]
    labels_list=[]
    with open(finename) as f:
        lines = f.readlines()
        for line in lines:
            #rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
            content=line.rstrip().split(' ')
            name=content[0]
            labels=[]
            for value in content[1:]:
                labels.append(float(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list,labels_list

if __name__ == '__main__':
    filename='./training_data/train.txt'
    images_list, labels_list=load_image_labels(filename)

    # 若输入为np.arange数组,则需要tolist()为list类型,如:
    # images_list = np.reshape(np.arange(8*3), (8,3))
    # labels_list = np.reshape(np.arange(8*3), (8,3))
    # images_list=images_list.tolist()
    # labels_list=labels_list.tolist()

    iter = 5# 迭代5次,每次输出一个batch个
    batch = get_data_batch([images_list, labels_list], batch_size=3, shuffle=False)
    for i in range(iter):
        print('**************************')
        # train_batch=batch.__next__()
        batch_images,batch_labels=get_next_batch(batch)
        print('batch_images:{}'.format(batch_images))
        print('batch_labels:{}'.format(batch_labels))


   运行输出结果为:

**************************
batch_images:['1.jpg', '2.jpg', '3.jpg']
batch_labels:[[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
**************************
batch_images:['4.jpg', '5.jpg', '6.jpg']
batch_labels:[[4.0, 14.0], [5.0, 15.0], [6.0, 16.0]]
**************************
batch_images:['7.jpg', '8.jpg', '1.jpg']
batch_labels:[[7.0, 17.0], [8.0, 18.0], [1.0, 11.0]]
**************************
batch_images:['2.jpg', '3.jpg', '4.jpg']
batch_labels:[[2.0, 12.0], [3.0, 13.0], [4.0, 14.0]]
**************************
batch_images:['5.jpg', '6.jpg', '7.jpg']
batch_labels:[[5.0, 15.0], [6.0, 16.0], [7.0, 17.0]]

Process finished with exit code 0

二、TensorFlow循环产生批量数据batch 

    使用TensorFlow实现产生批量数据batch,需要几个接口,

(1) tf.train.slice_input_producer

tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

slice_input_producer(tensor_list,
                     num_epochs=None,
                     shuffle=True,
                     seed=None,
                     capacity=32,
                     shared_name=None,
                     name=None)
# 第一个参数
#           tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。
# 第二个参数num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置
#           num_epochs = None, 生成器可以无限次遍历tensor列表,如果设置为
#           num_epochs = N,生成器只能遍历tensor列表N次。
# 第三个参数shuffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle = True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用
#           tf.train.batch函数就可以了;
#           如果shuffle = False, 就需要在批处理时候使用
#           tf.train.shuffle_batch函数打乱样本。
# 第四个参数seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle = True的情况下才有用。
# 第五个参数capacity:设置tensor列表的容量。
# 第六个参数shared_name:可选参数,如果设置一个‘shared_name’,则在不同的上下文环境(Session)中可以通过这个名字共享生成的tensor。
# 第七个参数name:可选,设置操作的名称

    tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。

    例子:

import tensorflow as tf
 
images = ['img1', 'img2', 'img3', 'img4', 'img5']
labels= [1,2,3,4,5]
 
epoch_num=8
 
f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=False)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(epoch_num):
        k = sess.run(f)
        print '************************'
        print (i,k)
 
    coord.request_stop()
    coord.join(threads)

(2) tf.train.batch和tf.train.shuffle_batch

    tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到文件队列,作为训练一个batch的数据,等待tensor出队执行计算。

tf.train.batch(tensors, 
               batch_size, 
               num_threads=1, 
               capacity=32,
               enqueue_many=False, 
               shapes=None, 
               dynamic_pad=False,
               allow_smaller_final_batch=False, 
               shared_name=None, 
               name=None)
# 第一个参数tensors:tensor序列或tensor字典,可以是含有单个样本的序列;
# 第二个参数batch_size: 生成的batch的大小;
# 第三个参数num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
# 第四个参数capacity: 定义生成的tensor序列的最大容量;
# 第五个参数enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
# 第六个参数shapes: 可选参数,默认是推测出的传入的tensor的形状;
# 第七个参数dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
# 第八个参数allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
# 第九个参数shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
# 第十个参数name: 操作的名称;

    如果tf.train.batch的第一个参数 tensors 传入的是tenor列表或者字典,返回的是tensor列表或字典,如果传入的是只含有一个元素的列表,返回的是单个的tensor,而不是一个列表。

    与tf.train.batch函数相对的还有一个tf.train.shuffle_batch函数,两个函数作用一样,都是生成一定数量的tensor,组成训练一个batch需要的数据集,区别是tf.train.shuffle_batch会打乱样本顺序。

(3) TF循环产生批量数据batch 的完整例子

# -*-coding: utf-8 -*-
"""
    @Project: LSTM
    @File   : tf_create_batch_data.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-10-28 17:50:24
"""
import tensorflow as tf


def get_data_batch(inputs,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):
    '''
    :param inputs: 输入数据,可以是多个list
    :param batch_size:
    :param labels_nums:标签个数
    :param one_hot:是否将labels转为one_hot的形式
    :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
    :return:返回batch的images和labels
    '''
    # 生成队列
    inputs_que= tf.train.slice_input_producer(inputs, shuffle=shuffle)
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保证capacity必须大于min_after_dequeue参数值
    if shuffle:
        out_batch = tf.train.shuffle_batch(inputs_que,
                                            batch_size=batch_size,
                                            capacity=capacity,
                                            min_after_dequeue=min_after_dequeue,
                                            num_threads=num_threads)
    else:
        out_batch = tf.train.batch(inputs_que,
                                    batch_size=batch_size,
                                    capacity=capacity,
                                    num_threads=num_threads)
    return out_batch

def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False):
    '''
    :param images:图像
    :param labels:标签
    :param batch_size:
    :param labels_nums:标签个数
    :param one_hot:是否将labels转为one_hot的形式
    :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
    :return:返回batch的images和labels
    '''
    images_que, labels_que= tf.train.slice_input_producer([images,labels], shuffle=shuffle)
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保证capacity必须大于min_after_dequeue参数值
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([images_que, labels_que],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue)
    else:
        images_batch, labels_batch = tf.train.batch([images_que, labels_que],
                                                    batch_size=batch_size,
                                                    capacity=capacity)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch,labels_batch

def load_image_labels(finename):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
    :param test_files:
    :return:
    '''
    images_list=[]
    labels_list=[]
    with open(finename) as f:
        lines = f.readlines()
        for line in lines:
            #rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
            content=line.rstrip().split(' ')
            name=content[0]
            labels=[]
            for value in content[1:]:
                labels.append(float(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list,labels_list

if __name__ == '__main__':
    filename='./training_data/train.txt'
    # 输入数据可以是list,也可以是np.array
    images_list, labels_list=load_image_labels(filename)
    # np.arange数组如:
    # images_list = np.reshape(np.arange(8*3), (8,3))
    # labels_list = np.reshape(np.arange(8*3), (8,3))

    iter = 5  # 迭代5次,每次输出一个batch个
    # batch_images, batch_labels = get_data_batch( inputs=[images_list, labels_list],batch_size=3,labels_nums=2,one_hot=False,shuffle=False,num_threads=1)
    # 或者
    batch_images, batch_labels = get_batch_images(images_list,labels_list,batch_size=3,labels_nums=2,one_hot=False,shuffle=False)
    with tf.Session() as sess:  # 开始一个会话
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(iter):
            # 在会话中取出images和labels
            images, labels = sess.run([batch_images, batch_labels] )
            print('**************************')
            print('batch_images:{}'.format(images ))
            print('batch_labels:{}'.format(labels))

        # 停止所有线程
        coord.request_stop()
        coord.join(threads)

    运行输出结果:

**************************
batch_images:[b'1.jpg' b'2.jpg' b'3.jpg']
batch_labels:[[ 1. 11.] [ 2. 12.][ 3. 13.]]
**************************
batch_images:[b'4.jpg' b'5.jpg' b'6.jpg']
batch_labels:[[ 4. 14.] [ 5. 15.][ 6. 16.]]
**************************
batch_images:[b'7.jpg' b'8.jpg' b'1.jpg']
batch_labels:[[ 7. 17.][ 8. 18.][ 1. 11.]]
**************************
batch_images:[b'2.jpg' b'3.jpg' b'4.jpg']
batch_labels:[[ 2. 12.] [ 3. 13.][ 4. 14.]]
**************************
batch_images:[b'5.jpg' b'6.jpg' b'7.jpg']
batch_labels:[[ 5. 15.][ 6. 16.][ 7. 17.]]

猜你喜欢

转载自blog.csdn.net/guyuealian/article/details/83473298