tensorflow学习(11):数据集Dataset

本文将介绍Dataset的基本使用方法,包括从文件创建数据集,使用迭代器遍历数据集等。

数据集的基本使用方法

在数据集框架中,每一个数据集代表一个数据来源:数据可能来自于一个张量,一个TFRecord文件,一个文本文件等等。由于训练数据通常无法全部写入内存中,从数据集中读取数据时需要一个迭代器(iterator)按顺序进行读取,这点与队列的dequeue()操作和Reader的read()操作相似。
1.从一个张量创建一个数据集,遍历这个数据集,并对每个输入输入y = x2 的值。

import tensorflow as tf
input_data = [1,2,3,5,8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
    for i in range(len(input_data)):
        print(sess.run(y))

运行结果如下:

在这里插入图片描述
从上面这个简单的例子中可以看到,利用数据集读取数据有三个基本步骤:
1)定义数据集的构造方法,如tf.data.Dataset.from_tensor_slices()
2)定义遍历器,如make_one_shot_iterator()
3)使用get_next()方法从遍历器中读取数据张量

2.在图像相关任务中,输入数据通常以TFRecord形式存储。我们需要从TFRecord文件中读取数据,并创建一个数据集。由于每一个TFRecord文件都有自己不同的feature格式,因此在读取TFRecord时,需要提供一个parse函数来解析所读取的TFRecord的数据格式

import tensorflow as tf
def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image':tf.FixedLenFeature([],tf.int64),
            'label':tf.FixedLenFeature([],tf.int64)
        })
    return features['image'],features['label']

#从TFRecord文件创建数据集
input_files = ["path/to/input_file1","path/to/input_file2"]  #可以是多个文件
dataset = tf.data.TFRecordDataset(input_files)
#map()函数表示对数据集中的每条数据进行调用相应方法。使用TFRecordDataset读出的是
#二进制数据,这里需要通过map()来调用parse()对二进制进行解析。类似的,map()也可以
#用来完成其他的数据预处理工作
dataset = dataset.map(parse)

iterator = dataset.make_one_shot_iterator()
image,label = iterator.get_next()
with tf.Session() as tf:
    for i in range(10):
        a1,a2 = sess.run([image,label])

上述例子使用了最简单的one_shot_iterator来遍历数据集。在使用one_shot_iterator时,数据集的所有参数必须已经确定,因此one_shot_iterator不需要特别的初始化过程,如果要用placeholder来初始化数据集,那就需要用到initializable_iterator:

import tensorflow as tf
def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image':tf.FixedLenFeature([],tf.int64),
            'label':tf.FixedLenFeature([],tf.int64)
        })
    return features['image'],features['label']

input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parse)

iterator = dataset.make_initializable_iterator()
image,label = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={input_files:["path/to/input_file1","path/to/input_file2"]})
    while True:
        try:
            sess.run([image,label])
        except tf.errors.OutOfRangeError:
            break

猜你喜欢

转载自blog.csdn.net/shanlepu6038/article/details/85061918
今日推荐