Tensorflow - 生成批量数据 - 方法3

每次只产生一个batch的数据

# coding: utf-8
from __future__ import print_function

import tensorflow as tf
import random
import numpy as np


class ToySequenceData(object):
    """ 生成序列数据。每个数量可能具有不同的长度。
    一共生成下面两类数据
    - 类别 0: 线性序列 (如 [0, 1, 2, 3,...])
    - 类别 1: 完全随机的序列 (i.e. [1, 3, 10, 7,...])
    注意:
    max_seq_len是最大的序列长度。对于长度小于这个数值的序列,我们将会补0。
    在送入RNN计算时,会借助sequence_length这个属性来进行相应长度的计算。
    """
    def __init__(self, n_samples=1000, max_seq_len=20, min_seq_len=3,
                 max_value=1000):
        self.data = []
        self.labels = []
        self.seqlen = []
        for i in range(n_samples):
            # 序列的长度是随机的,在min_seq_len和max_seq_len之间。
            len = random.randint(min_seq_len, max_seq_len)
            # self.seqlen用于存储所有的序列。    实际的序列长度,不算0
            self.seqlen.append(len)
            # 以50%的概率,随机添加一个线性或随机的训练
            if random.random() < .5:
                # 生成一个线性序列
                rand_start = random.randint(0, max_value - len)
                s = [[float(i)/max_value] for i in range(rand_start, rand_start + len)]
                # 长度不足max_seq_len的需要补0
                s += [[0.] for i in range(max_seq_len - len)]
                self.data.append(s)
                # 线性序列的label是[1, 0](因为我们一共只有两类)
                self.labels.append([1., 0.])
            else:
                # 生成一个随机序列
                s = [[float(random.randint(0, max_value))/max_value] for i in range(len)]
                # 长度不足max_seq_len的需要补0
                s += [[0.] for i in range(max_seq_len - len)]
                self.data.append(s)
                self.labels.append([0., 1.])
        self.batch_id = 0  # batch_id 是全局变量,因此记录了累加值

    def next(self, batch_size):
        """
        生成batch_size的样本。
        如果使用完了所有样本,会重新从头开始。
        """
        if self.batch_id == len(self.data):
            self.batch_id = 0
        batch_data = (self.data[self.batch_id:min(self.batch_id + batch_size, len(self.data))])
        batch_labels = (self.labels[self.batch_id:min(self.batch_id + batch_size, len(self.data))])
        batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id + batch_size, len(self.data))])
        self.batch_id = min(self.batch_id + batch_size, len(self.data))
        return batch_data, batch_labels, batch_seqlen


# 这一部分只是测试一下如何使用上面定义的ToySequenceData
tmp = ToySequenceData()

# 生成样本
batch_data, batch_labels, batch_seqlen = tmp.next(32)

# batch_data是序列数据,它是一个嵌套的list,形状为(batch_size, max_seq_len, 1)
print(np.array(batch_data).shape)  # (32, 20, 1)

# 我们之前调用tmp.next(32),因此一共有32个序列
# 我们可以打出第一个序列
print(batch_data[0])  # 形如 [[0.084], [0.085].....[0.086], [0.087], [0.088]

# batch_labels是label,它也是一个嵌套的list,形状为(batch_size, 2)
# (batch_size, 2)中的“2”表示为两类分类
print(np.array(batch_labels).shape)  # (32, 2)

# 我们可以打出第一个序列的label
print(batch_labels[0])  # [1.0, 0.0]

# batch_seqlen一个长度为batch_size的list,表示每个序列的实际长度
print(np.array(batch_seqlen).shape)  # (32,)

# 我们可以打出第一个序列的长度
print(batch_seqlen[0])

#  来自tensorflow21项目 第13章

发布了18 篇原创文章 · 获赞 5 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/Zhou_Dao/article/details/103754801