Tensorflow使用Dataset读取数据

(1)one shot iterator 单次迭代

仅支持对数据集进行一次迭代,不需要显式初始化。

import tensorflow as tf

# 通过对list切片的方式创建数据集
dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 2, 2]))
# 打乱顺序,设置batch,设置重复次数
dataset = dataset.shuffle(buffer_size=1000).batch(4).repeat(2)
# 创建迭代器
itr = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        v = sess.run([itr])
        print(v)

运行结果:

[(array([1, 2, 5, 6]), array([0, 0, 2, 2]))]
[(array([4, 3]), array([1, 1]))]
[(array([5, 4, 1, 6]), array([2, 1, 0, 2]))]
[(array([2, 3]), array([0, 1]))]
Traceback (most recent call last):
  File "D:\App\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call

如果有101个样本,batch size是100,那么这个函数会返回两个batch,第一个是100,第二个是1。

如果想要强制每个batch 的数目是一样的那么就在batch 的时候设置drop_remainder=True。这样101个样本,只会返回一个数目是100的batch,剩下的一个样本会被丢弃。

dataset.shuffle(buffer_size=1000).batch(128, drop_remainder=True).repeat(20000)

(2)initializable iterator 带参数的迭代器

这种迭代器允许使用placeholder,在运行时传送参数。

比如下面的例子,创建一个从0~p的数据集,p是参数,在运行时传入p。

import tensorflow as tf

# 对数据集传入参数 p
p = tf.placeholder(tf.int64, shape=[])

# 创建从0到p-1的数据集
dataset = tf.data.Dataset.range(p)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Initialize an iterator over a dataset with 3 elements.
    sess.run(iterator.initializer, feed_dict={p: 3})
    for i in range(3):
        print(sess.run(next_element))

    # Initialize the same iterator over a dataset with 10 elements.
    sess.run(iterator.initializer, feed_dict={p: 10})
    for i in range(10):
        print(sess.run(next_element))

(3)可切换数据集的迭代器

这种情况下不同Dataset的尺寸要一致。

import tensorflow as tf

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(10).map(lambda x: x - 100).batch(5)
validation_dataset = tf.data.Dataset.range(10).batch(5)

# 根据training_dataset的形状创建迭代器
iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

with tf.Session() as sess:
    # Run 20 epochs in which the training dataset is traversed, followed by the
    # validation dataset.
    for _ in range(5):
        print("======")
        # Initialize an iterator over the training dataset.
        sess.run(training_init_op)
        for _ in range(2):
            print(sess.run(next_element))
        print("------")
        # Initialize an iterator over the validation dataset.
        sess.run(validation_init_op)
        for _ in range(1):
            print(sess.run(next_element))

运行结果

======
[-100  -99  -98  -97  -96]
[-95 -94 -93 -92 -91]
------
[0 1 2 3 4]
======
[-100  -99  -98  -97  -96]
[-95 -94 -93 -92 -91]
------
[0 1 2 3 4]
====== (后面省略)

validation从未输出5~10,可以得出每一次执行初始化都会丢失之前的进度。

(4)可切换数据集并保留进度的迭代器

可feed迭代器可以与 tf.placeholder 一起使用,通过feed_dict机制选择每次调用 tf.Session.run 时所使用的 Iterator。它功能与(3)相同,但在迭代器之间切换时不需要从数据集的开头初始化迭代器。

import tensorflow as tf

p = tf.placeholder(tf.int64, shape=[])

# 定义训练数据集和验证数据集。两个数据集的结构相同。
training_dataset = tf.data.Dataset.range(10).map(lambda x: x - 100).batch(5).repeat(1000)
validation_dataset = tf.data.Dataset.range(p).batch(5).repeat(1000)

# 这种迭代器需要传入一个handle和数据集的结构
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# 创建两个不同的迭代器(训练迭代器和验证迭代器)
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

with tf.Session() as sess:
    # 获得两个迭代器的句柄
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    sess.run(validation_iterator.initializer, feed_dict={p: 10})

    while True:
        print("======")
        # 通过feed_dict=handle在两个迭代器之间切换。现在换成训练迭代器。
        for _ in range(1):
            print(sess.run(next_element, feed_dict={handle: training_handle}))
        print("------")
        # 现在换成迭代训练器。
        for _ in range(1):
            print(sess.run(next_element, feed_dict={handle: validation_handle}))

运行结果

======
[-100  -99  -98  -97  -96]
------
[0 1 2 3 4]
======
[-95 -94 -93 -92 -91]
------
[5 6 7 8 9]
======
[-100  -99  -98  -97  -96]
------
[0 1 2 3 4]
======
[-95 -94 -93 -92 -91]
------
[5 6 7 8 9] (后面的部分省略)

从运行结果可以看出,切换数据集时确实保留了上次的进度。

发布了80 篇原创文章 · 获赞 22 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/u010099177/article/details/101057658