【深度学习】打乱数据及keras自定义generator

版权声明:本人原创,转载请注明出处 https://blog.csdn.net/zjn295771349/article/details/86616979

一、打乱数据

在深度学习中,打乱数据是很重要的。比如,训练集、验证集和测试集需要来自同一分布,所以要打乱数据集再分离,这样就能保证训练集、验证集和测试集的数据分布都是相同的。再比如,制作minibatch的时候,每经过一次epoch都要打乱一次数据集,使每次输入的minibatch分布都不相同,可见打乱数据的重要性。

假如你的内存能装下整个数据集那么,就可以这样:

import numpy as np
#以样本数量生成乱序的list
permutation = list(np.random.permutation(x.shape[0]))
#按照随机生成的顺序重新排列数据集
x = x[permutation,:,:,:]#训练集有4个维度,第一个维度指的是样本数量
y = y[permutation,:]

二、keras自定义generator

当你的内存装不下整个数据集,那么可以把数据集做成.h5文件,然后每次制作minibatch的时候从硬盘中读取。那么就可以自己写个generator。

import h5py
import numpy as np

train_dataset = h5py.File('trainingset_224x224.h5', "r")

def generate_train_from_file(train_dataset, m, batch_size=64):
    #m指的是样本数量
    num_minibatches = int(m/batch_size)
    while 1:
        permutation = list(np.random.permutation(m))
        for i in range(num_minibatches):
            index = permutation[i*batch_size:(i+1)*batch_size] 
            #从.h5文件读取数据时,索引必须是从小到大排列的,所以读取之前把索引排个序
            index.sort()
            minibatches_x = np.array(train_dataset['trainingset_x'][index])
            minibatches_y = np.array(train_dataset['trainingset_y'][index])

            yield minibatches_x, minibatches_y

#这么调用就可以了
model.fit_generator(generator = generate_train_from_file(train_dataset, m, batch_size=batch_size),...)
                    

注意,在函数中需要用while写成死循环,因为每个epoch不会重新调用方法。当函数以yield关键词返回,那么这个函数则是个生成器,生成器指的是当它返回数据之后再次执行时再从这个地方继续执行。

猜你喜欢

转载自blog.csdn.net/zjn295771349/article/details/86616979