【tf.keras.utils.Sequence】构建自己的数据集生成器

every blog every motto: You can do more than you think.

0. 前言

在训练模型时,我们往往不一次将数据全部加载进内存中,而是将数据分批次加载到内存中。


  • 一种方法是用 while True 遍历数据,用yeid产生,具体可参考语义分割代码讲解部分
  • 另一种方法是本文即将讲解的tf.keras.utils.Sequence方法

1. 正文

__ len __ 中返回的即1个epoch迭代的次数,即:
总样本数/ batch_size

__ getitem __ 根据len中的迭代次数,生成数据


注意: __ len __ ,__ getitem __ 必须要实现

"""
测试
__getitem__
"""
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf


class Date(tf.keras.utils.Sequence):

    def __init__(self):
        print('初始化相关参数')

    def __len__(self):
        """
        此方法要实现,否则会报错
        正常程序中返回1个epoch迭代的次数
        :return:
        """
        return 5

    def __getitem__(self, index):
        """生成一个batch的数据"""
        print('index:', index)
        x_batch = ['x1', 'x2', 'x3', 'x4']
        y_batch = ['y1', 'y2', 'y3', 'y4']
        print('-'*20)
        return x_batch, y_batch


# 实例化数据
date = Date()

for batch_number, (x, y) in enumerate(date):
    print('正在进行第{} batch'.format(batch_number))
    print('x_batch:', x)
    print('y_batcxh:', y)

结果:
在这里插入图片描述

参考文献

[1] https://blog.csdn.net/weixin_39190382/article/details/105808830
[2] https://blog.csdn.net/weixin_43198141/article/details/89926262
[3] https://blog.csdn.net/u011311291/article/details/80991330

猜你喜欢

转载自blog.csdn.net/weixin_39190382/article/details/109195031