A Simple Implementation of a next_batch() Function for Reading Data in Batches

How the batch is read

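Method
# Randomly pick batch_size training samples
import numpy as np

def next_batch(train_data, train_target, batch_size):
    # Build the list of sample indices 0..N-1 and shuffle it in place
    index = list(range(len(train_target)))
    np.random.shuffle(index)
    batch_data = []
    batch_target = []
    # Take the first batch_size shuffled indices (sampling without replacement);
    # min() avoids an IndexError when batch_size exceeds the number of samples
    for i in range(min(batch_size, len(index))):
        batch_data.append(train_data[index[i]])
        batch_target.append(train_target[index[i]])
    return batch_data, batch_target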
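
This sketch assumes that train_data and train_target are NumPy arrays whose first axis indexes the samples. The function above also accepts plain lists. Under that assumption, the Python loop can be replaced with fancy indexing. The name next_batch_np and the rng parameter are hypothetical. The sketch draws distinct indices with numpy.random.Generator.choice(..., replace=False) and indexes both arrays with the same index array, so each sample stays paired with its label.

```python
import numpy as np

def next_batch_np(train_data, train_target, batch_size, rng=None):
    """Return up to batch_size randomly chosen (data, target) pairs.

    Hypothetical vectorized variant: assumes both inputs are NumPy arrays
    whose first axis indexes the samples.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = len(train_target)
    # Draw distinct sample indices (sampling without replacement)
    idx = rng.choice(n, size=min(batch_size, n), replace=False)
    # Indexing both arrays with the same idx keeps data and targets aligned
    return train_data[idx], train_target[idx]


if __name__ == "__main__":
    X = np.arange(20).reshape(10, 2)   # 10 samples, 2 features each
    y = np.arange(10)                  # label i belongs to row i
    xb, yb = next_batch_np(X, y, 4, rng=np.random.default_rng(0))
    print(xb.shape, yb.shape)
```

Passing an explicit Generator with a fixed seed makes the batches reproducible. Omitting it gives fresh randomness on each call.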

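Explanation: the function uses NumPy's shuffle() to randomly permute the index list in place. Taking the first batch_size shuffled indices then gives a random sample without replacement.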
numpy.random.shuffle(x)
Modify a sequence in-place by shuffling its contents.

This function only shuffles the array along the first axis of a multi-dimensional array. The order of sub-arrays is changed but their contents remain the same.
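
The first-axis behaviour matters when each sample is a row of a 2-D array. np.random.shuffle reorders the rows but leaves every row intact. The small check below confirms this; the array values are chosen only for illustration.

```python
import numpy as np

arr = np.arange(12).reshape(4, 3)   # 4 rows ("samples"), 3 columns each
original_rows = {tuple(row) for row in arr.tolist()}

np.random.shuffle(arr)              # shuffles in place, along axis 0 only

# Every row still exists unchanged; only the row order may differ
shuffled_rows = {tuple(row) for row in arr.tolist()}
print(shuffled_rows == original_rows)   # True
```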

>>> arr = np.arange(10)
>>> np.random.shuffle(arr)
>>> arr
array([1, 7, 5, 2, 9, 4, 3, 6, 0, 8])  # random
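
This is also why next_batch() shuffles an index list rather than the data itself. Calling np.random.shuffle separately on train_data and train_target would apply two independent permutations and break the sample/label pairing. The minimal sketch below uses made-up arrays and applies one shared permutation, from numpy.random.Generator.permutation, to both.

```python
import numpy as np

rng = np.random.default_rng(42)
data = np.arange(10) * 10     # sample i has value 10*i (illustrative)
target = np.arange(10)        # label i belongs to sample i

# One permutation of the sample indices, shared by both arrays
perm = rng.permutation(len(target))
data_shuffled = data[perm]
target_shuffled = target[perm]

# Pairing survives: each shuffled sample still matches its label
print(np.array_equal(data_shuffled, target_shuffled * 10))  # True
```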


Reposted from blog.csdn.net/gsww404/article/details/80381629