Common usage of Dataset in TensorFlow

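tf.data.Dataset.from_tensor_slices(array) creates a dataset by slicing the input array along its 0th (first) dimension, so each element of the dataset is one slice (for a 2-D array, one row).
dataset.make_initializable_iterator() creates an iterator over the dataset; this iterator must be initialized before it can be used.

one_element = iterator.get_next() returns a tensor that yields the next element from the iterator each time it is evaluated with sess.run.
iterator.initializer is the operation that initializes the iterator; run it with sess.run before reading any elements.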
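Example: iterate over a random 5x2 array one row at a time.
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)
import numpy as np
# Slice the (5, 2) array along dimension 0: 5 elements of shape (2,)
dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))
iterator = dataset.make_initializable_iterator()
one_element = iterator.get_next()  # each sess.run() returns the next row
with tf.Session() as sess:
    sess.run(iterator.initializer)  # the iterator must be initialized first
    for i in range(5):
        print(sess.run(one_element))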
Output (the values are random, so yours will differ):
[0.89118383 0.24899264]
[0.31330564 0.38645845]
[0.92109775 0.89482123]
[0.3567731 0.9308349]
[0.72154475 0.12027287]
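To make these semantics concrete without a TensorFlow session, here is a minimal pure-Python model of the first example. It is only a sketch: from_tensor_slices, InitializableIterator and OutOfRangeError below are hypothetical stand-ins written for illustration, not TensorFlow's implementation. The model slices along dimension 0, also accepts a tuple of arrays (each element becomes a tuple of the corresponding rows, which is how from_tensor_slices treats tuples), and shows that get_next only works after initialization and signals the end of the data with an exception.

```python
class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def from_tensor_slices(tensors):
    """Yield slices of `tensors` along dimension 0 (hypothetical model).

    `tensors` is a nested list (each row becomes one element) or a tuple of
    nested lists with the same first dimension (each element is a tuple of
    the i-th rows).
    """
    if isinstance(tensors, tuple):
        sizes = {len(t) for t in tensors}
        if len(sizes) != 1:
            raise ValueError("all components must have the same size in dimension 0")
        for i in range(sizes.pop()):
            yield tuple(t[i] for t in tensors)
    else:
        yield from tensors


class InitializableIterator:
    """Hypothetical model of an iterator that must be initialized before use."""

    def __init__(self, make_elements):
        self._make_elements = make_elements  # callable returning a fresh iterable
        self._it = None

    def initialize(self):
        # Plays the role of sess.run(iterator.initializer): (re)start at element 0.
        self._it = iter(self._make_elements())

    def get_next(self):
        # Plays the role of sess.run(iterator.get_next()).
        if self._it is None:
            # TF 1.x reports this case with tf.errors.FailedPreconditionError.
            raise RuntimeError("iterator has not been initialized")
        try:
            return next(self._it)
        except StopIteration:
            raise OutOfRangeError("end of sequence") from None


rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # shape (3, 2)
iterator = InitializableIterator(lambda: from_tensor_slices(rows))
iterator.initialize()
for _ in range(3):
    print(iterator.get_next())  # [0.1, 0.2], then [0.3, 0.4], then [0.5, 0.6]

print(list(from_tensor_slices(([[1, 2], [3, 4]], [0, 1]))))  # [([1, 2], 0), ([3, 4], 1)]
```

Running initialize() again restarts from the first element, which mirrors re-running iterator.initializer in a session.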
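Batching and repeating: the 100 elements are grouped into batches of 20 and the whole sequence is repeated 5 times, giving 25 batches. The loop reads until the iterator is exhausted, at which point sess.run raises tf.errors.OutOfRangeError.
arr = np.arange(100)
dataset = tf.data.Dataset.from_tensor_slices(arr)
dataset = dataset.batch(20).repeat(5)  # 5 batches per pass x 5 passes = 25 batches
iterator = dataset.make_initializable_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:  # raised once all 25 batches have been consumed
        print("end")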
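The order of batch and repeat matters when the dataset size is not a multiple of the batch size. The sketch below models both transformations on Python lists; batch and repeat here are hypothetical helpers that mimic the element order of dataset.batch and dataset.repeat, not TensorFlow code.

```python
def batch(elements, batch_size):
    """Group consecutive elements; the last batch may be smaller (hypothetical model)."""
    elements = list(elements)
    return [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]


def repeat(elements, count):
    """Concatenate `count` passes over the elements (hypothetical model)."""
    return list(elements) * count


# Same shapes as the example above: 100 elements, batch(20).repeat(5).
print(len(repeat(batch(range(100), 20), 5)))  # 25

small = range(10)
# batch then repeat: each pass ends with its own short batch
print(repeat(batch(small, 4), 2))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# repeat then batch: batches run across the boundary between passes
print(batch(repeat(small, 2), 4))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]
```

With 100 elements and a batch size of 20, both orders produce the same 25 full batches, so the example above does not show the difference.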
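Shuffling, mapping and batching: the elements are shuffled, 1 is added to each, and the result is grouped into batches of 20, so the 5 iterations below print all 100 values (1 to 100) in random order.
arr = np.arange(100)
dataset = tf.data.Dataset.from_tensor_slices(arr)
# buffer_size >= number of elements (100), so this is a full shuffle
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.map(lambda x: x + 1).batch(20)  # add 1 to each element, then batch
iterator = dataset.make_initializable_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(5):  # 100 / 20 = 5 batches
        print(sess.run(one_element))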

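The three most common Dataset transformations are map, batch and shuffle:

dataset.map(f) applies the function f to every element of the dataset.

dataset.batch(num) combines consecutive elements into batches of num elements; if the dataset size is not a multiple of num, the last batch is smaller.

dataset.shuffle(buffer_size) randomly shuffles the elements by sampling from a buffer of buffer_size elements; a buffer at least as large as the dataset gives a full shuffle.

The batching example also uses dataset.repeat(count), which repeats the whole dataset count times.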
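shuffle does not load the whole dataset at once: it keeps a buffer of buffer_size elements, emits a randomly chosen one, and refills the freed slot with the next input element. The hypothetical shuffle function below is a pure-Python sketch of that buffer strategy (not TensorFlow's code), followed by the shuffle, map, batch pipeline of the third example. With buffer_size=1 nothing is shuffled, and with a small buffer the first output can only come from the first buffer_size inputs.

```python
import random


def shuffle(elements, buffer_size, seed=None):
    """Buffer-based shuffle (hypothetical model of dataset.shuffle)."""
    rng = random.Random(seed)
    buffer, out = [], []

    def emit_random():
        # Output a random buffer element; move the last element into its slot.
        i = rng.randrange(len(buffer))
        out.append(buffer[i])
        buffer[i] = buffer[-1]
        buffer.pop()

    for x in elements:
        buffer.append(x)
        if len(buffer) == buffer_size:  # buffer full: emit one element
            emit_random()
    while buffer:  # input exhausted: drain the buffer in random order
        emit_random()
    return out


# shuffle(buffer_size=10000) -> map(x + 1) -> batch(20), as in the third example
shuffled = shuffle(range(100), buffer_size=10000, seed=0)
mapped = [x + 1 for x in shuffled]
batches = [mapped[i:i + 20] for i in range(0, len(mapped), 20)]
for b in batches:
    print(b)  # 5 batches of 20; together they hold 1..100 in random order

print(shuffle(range(10), buffer_size=1))  # [0, 1, ..., 9]: a buffer of 1 keeps the order
```

Shuffling before map and batch, as the example does, shuffles individual elements; shuffling after batch would only reorder whole batches.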

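file_path = "E:\\tf_project\\NMT\\train.txt.zh"
dataset = tf.data.TextLineDataset(file_path)  # one string element per line of the file
dataset = dataset.batch(10)
iterator = dataset.make_initializable_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(3):  # print the first 3 batches of 10 lines
        # elements are byte strings; decode with .decode("utf-8") to show the Chinese text
        print(sess.run(one_element))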
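tf.data.TextLineDataset(file_path) reads a text file line by line: each element of the dataset is one line of the file, as a string without the trailing newline.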
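Reading a file line by line and batching the lines can be modeled without TensorFlow as well. In this sketch text_line_dataset and batched are hypothetical helpers: like TextLineDataset, text_line_dataset yields each line as a byte string without its trailing newline, and batched groups lines the way batch(10) does. A temporary file stands in for the train.txt.zh file of the example.

```python
import os
import tempfile


def text_line_dataset(path):
    """Yield each line of `path` as bytes without the trailing newline (hypothetical model)."""
    with open(path, "rb") as f:
        for line in f:
            yield line[:-1] if line.endswith(b"\n") else line


def batched(elements, n):
    """Group consecutive elements into lists of n; the last one may be shorter."""
    batch = []
    for x in elements:
        batch.append(x)
        if len(batch) == n:
            yield batch
            batch = []
    if batch:
        yield batch


# Write 25 short UTF-8 lines to a temporary file (a stand-in for train.txt.zh).
with tempfile.NamedTemporaryFile("w", encoding="utf-8", newline="\n",
                                 suffix=".txt", delete=False) as tmp:
    for i in range(25):
        tmp.write(f"句子 {i}\n")
    path = tmp.name

try:
    batches = list(batched(text_line_dataset(path), 10))
finally:
    os.remove(path)

for b in batches[:3]:
    # Decode the bytes to see the Chinese text instead of escaped byte values.
    print([line.decode("utf-8") for line in b])
```

With 25 lines and a batch size of 10 there are three batches of 10, 10 and 5 lines, matching how batch keeps a short final batch.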

Reposted from blog.csdn.net/z2539329562/article/details/80753355