Tensorflow 2.x:Dataset的repeat、shuffle和batch操作

Shuffle

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)

This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.

  • reshuffle_each_iteration控制每个epoch的顺序是否不同
>>> dataset = tf.data.Dataset.range(3)
>>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
>>> dataset = dataset.repeat(2)  # doctest: +SKIP
[1, 0, 2, 1, 0, 2]

Repeat

repeat(
    count=None
)

Repeats this dataset so each original value is seen count times.

>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
>>> dataset = dataset.repeat(3)
>>> list(dataset.as_numpy_iterator())
[1, 2, 3, 1, 2, 3, 1, 2, 3]

Batch

batch(
    batch_size, drop_remainder=False
)

Combines consecutive elements of this dataset into batches.

>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3, drop_remainder=True)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]

组合

假如,以下是实际数据

sequence = tf.range(10)
dataset = tf.data.Dataset.from_tensor_slices(sequence)
buffer_size = 5
batch_size = 4

1. shuffle -> batch

# 取数原理:取序列前`buffer_size`个元素放入buffer,每次从buffer随机选取`batch_size`个元素
#   buffer中元素被取出后,立即按顺序将下一个元素加入buffer,直至序列元素全部被消耗
for i, batch in enumerate(dataset.shuffle(buffer_size).batch(batch_size)):
    print(batch.numpy())
"""
[4 1 3 5]
[6 2 0 8]
[9 7]
"""

# 对于此例,7不会出现在第一个batch

2. batch -> shuffle

# 取数原理:将序列按`batch_size`切片,将前`buffer_size`分片放入buffer,每次从buffer中选取一个分片
#   buffer中元素/分片被取出后,立即按顺序将下一个分片加入buffer,直至全部分片被消耗
for i, batch in enumerate(dataset.batch(batch_size).shuffle(buffer_size)):
    print(batch.numpy())
"""
[4 5 6 7]
[0 1 2 3]
[8 9]
"""

3. repeat -> shuffle -> batch

# 取数原理:序列首尾相接,后序逻辑与方式1相同,同一batch数据可能重复
# 与shuffle -> repeat -> batch类似
for i, batch in enumerate(dataset.repeat().shuffle(buffer_size).batch(batch_size)):
    print(batch.numpy())
    if i == 4:
        break
"""
[4 3 2 7]
[5 0 9 6]
[0 1 3 1]
[5 6 7 8]
[4 2 9 0]
"""

4. shuffle -> batch -> repeat

# 取数原理:重复方法1,同一batch数据无重复
for i, batch in enumerate(dataset.shuffle(buffer_size).batch(batch_size).repeat()):
    print(batch.numpy())
    if i == 4:
        break
"""
[2 4 6 1]
[3 8 9 7]
[0 5]
[0 2 6 7]
[1 4 9 5]
"""

猜你喜欢

转载自blog.csdn.net/sinat_34072381/article/details/106340942
今日推荐