批量读取数据next_batch()的理解

批量读取数据

# 随机取batch_size个训练样本  
import numpy as np
#train_data训练集特征,train_target训练集对应的标签,batch_size
def next_batch(train_data, train_target, batch_size):  
    #打乱数据集
    index = [ i for i in range(0,len(train_target)) ]  
    np.random.shuffle(index);  
    #建立batch_data与batch_target的空列表
    batch_data = []; 
    batch_target = [];  
    #向空列表加入训练集及标签
    for i in range(0,batch_size):  
        batch_data.append(train_data[index[i]]);  
        batch_target.append(train_target[index[i]])  
    return batch_data, batch_target #返回 

猜你喜欢

转载自blog.csdn.net/qq_33373858/article/details/83012236