Shuffle
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None
)
This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.
- reshuffle_each_iteration 控制每个 epoch(每次遍历数据集)是否重新打乱顺序:默认为 True,每个 epoch 的顺序不同;设为 False 时,每个 epoch 的顺序与第一个 epoch 相同(见下例)。
>>> dataset = tf.data.Dataset.range(3)
>>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
>>> dataset = dataset.repeat(2)
>>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
[1, 0, 2, 1, 0, 2]
Repeat
repeat(
count=None
)
Repeats this dataset so each original value is seen count times. If count is None (the default) or -1, the dataset is repeated indefinitely.
>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
>>> dataset = dataset.repeat(3)
>>> list(dataset.as_numpy_iterator())
[1, 2, 3, 1, 2, 3, 1, 2, 3]
Batch
batch(
batch_size, drop_remainder=False
)
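Combines consecutive elements of this dataset into batches. With drop_remainder=True, a final batch smaller than batch_size is dropped (in the example below, the leftover elements 6 and 7).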
>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3, drop_remainder=True)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]
组合
假设实际数据如下(以下各例共用这些变量):
# Shared toy input for all examples below.
# NOTE(review): `tf` is not imported in this snippet; assumes
# `import tensorflow as tf` was run beforehand.
sequence = tf.range(10)  # the integers 0, 1, ..., 9
dataset = tf.data.Dataset.from_tensor_slices(sequence)  # one element per value
buffer_size = 5  # shuffle-buffer capacity (elements, or whole batches in case 2)
batch_size = 4  # number of consecutive items combined into one batch
1. shuffle -> batch
# How items are drawn: the first `buffer_size` elements fill the buffer. Each
# output element is picked at random from the buffer, and its slot is
# immediately refilled with the next input element (in order) until the input
# is exhausted. `batch` then groups every `batch_size` consecutive shuffled
# elements into one batch.
for batch in dataset.shuffle(buffer_size).batch(batch_size):
    print(batch.numpy())
"""
[4 1 3 5]
[6 2 0 8]
[9 7]
"""
# In this example the k-th pick (k = 1..4) is drawn from elements 0..(3 + k),
# because the buffer is refilled after every pick. So the first batch can only
# contain values from 0..7: 7 is possible (as the 4th element), but 8 and 9
# can never appear in the first batch.
2. batch -> shuffle
# How items are drawn: the sequence is first cut into consecutive batches of
# `batch_size` elements. The first `buffer_size` batches fill the buffer, and
# each output is a whole batch picked at random from it. The picked slot is
# refilled with the next batch (in order) until all batches are consumed.
# Only the ORDER of the batches is shuffled: every batch still holds
# consecutive elements in their original order. Here all 3 batches fit in the
# buffer (3 <= buffer_size) at once.
for batch in dataset.batch(batch_size).shuffle(buffer_size):
    print(batch.numpy())
"""
[4 5 6 7]
[0 1 2 3]
[8 9]
"""
3. repeat -> shuffle -> batch
# How items are drawn: `repeat()` with no count concatenates the sequence with
# itself endlessly (0..9, 0..9, ...). Shuffling and batching then work as in
# case 1. Because the buffer can hold elements from different epochs, the same
# value may appear more than once in one batch (e.g. `[0 1 3 1]` below).
# Compare with shuffle -> repeat -> batch: there each epoch is shuffled
# completely before the next one starts, so epoch boundaries are preserved and
# a duplicate can only occur in a batch that spans two epochs.
for batch in dataset.repeat().shuffle(buffer_size).batch(batch_size).take(5):
    # The repeated dataset is infinite; `take(5)` stops after five batches.
    print(batch.numpy())
"""
[4 3 2 7]
[5 0 9 6]
[0 1 3 1]
[5 6 7 8]
[4 2 9 0]
"""
4. shuffle -> batch -> repeat
# How items are drawn: each epoch is exactly case 1 (shuffle -> batch), and the
# epochs are played back to back. A batch never spans two epochs, so no value
# repeats within a batch. Since drop_remainder defaults to False, every epoch
# ends with a short batch (`[0 5]` below). The shuffled order also differs
# between epochs (compare `[2 4 6 1]` with `[0 2 6 7]`).
for batch in dataset.shuffle(buffer_size).batch(batch_size).repeat().take(5):
    # The repeated dataset is infinite; `take(5)` stops after five batches.
    print(batch.numpy())
"""
[2 4 6 1]
[3 8 9 7]
[0 5]
[0 2 6 7]
[1 4 9 5]
"""