TensorFlow之数据集(dataset)

学习记录

前面学习过队列和线程用于读取训练数据,今天学习《TensorFlow实战Google深度学习框架》一书中的另外一种数据的读取方式:通过数据集读取训练数据。
利用数据集读取数据有三个基本步骤:

  1. 定义数据集的构造方式。
  2. 定义迭代器。
  3. 使用get_next()方法从迭代器中读取数据张量,作为计算图其他部分的输出。

数据集的基本使用方法

首先注明两点:

  1. 在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个文本文件,或者一个TFRecord文件等。
  2. 用于训练数据通常无法全部写入内存中,所以从数据集读取数据时需要使用一个迭代器(Iterator)按顺序进行读取,这点与队列的dequeue()操作以及Reader的read()操作类似。

在正式叙述利用数据集读取训练数据之前,先集中了解一下数据集的创建。

数组创建数据集

在正式介绍利用队列与多线程读取数据之前,先给出一个简单的程序来生成样例数据。

# 从一个数组创建数据集
input_data = [1,2,3,5,8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)

# 定义一个迭代器用于遍历数据集。
# 因为上面定义的数据集没有用placeholder作为输入参数,
# 所以这里可以使用最简单的one_shot_iterator
iterator = dataset.make_one_shot_iterator()

# get_next()返回代表一个输入数据的张量,类似于队列的dequeue()
x = iterator.get_next()
y = x * x


with tf.Session() as sess:
    for i in range(len(input_data)):
        print(sess.run(y))
  1. tf.data.Dataset.from_tensor_slices()
    用于从一个数组创建一个数据集。

  2. make_one_shot_iterator() 产生一个单次的迭代器。这种迭代器进支持对数据进行一次迭代,迭代完成后如还要继续迭代就会报错:OutOfRangeError: End of sequence

  3. get_next()
    用于从迭代器中取出并返回一个数据

文本文档创建数据集

# 从文本文档创建数据集
input_files = [r'C:\Users\USER\Desktop\hekai.txt',r'C:\Users\USER\Desktop\1.txt']
dataset = tf.data.TextLineDataset(input_files)

iterator = dataset.make_one_shot_iterator()

x = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(x))
  1. tf.data.TextLineDataset()
    利用文本文档创建数据集

TFRecord文件创建数据集

# 从文本文档创建数据集
# 解析一个TFRecord的方法
def parser(record):
    features = tf.parse_single_example(
            record,
            features={
                    'image_raw': tf.FixedLenFeature([], tf.string),
                    'height': tf.FixedLenFeature([], tf.int64),
                    'width': tf.FixedLenFeature([], tf.int64),
                    'depth': tf.FixedLenFeature([], tf.int64),
                    'label': tf.FixedLenFeature([], tf.int64)
                    })
    return features

# 利用tf.data.TFRecordDataset()从TFRecord文件创建数据集
input_files = [r'D:\Anaconda3\code_hk\mobilenet_hk\1.tfrecords'] 
# 这里是一个文件列表,可以有多个TFRecord文件
dataset = tf.data.TFRecordDataset(input_files)

# map()方法表示对数据集中的每一条数据进行相应的调用方法。
# 使用TFRecordDataset读出的是二进制的数据,
# 这里需要通过map()来调用parser()对二进制数据进行解析
# 类似的,map()也可以用来完成其他数据处理工作
dataset = dataset.map(parser)

# 定义遍历数据集的迭代器
iterator = dataset.make_one_shot_iterator()

# 返回下一个样本数据
feat = iterator.get_next()

# 对于图片而言,还得进一步还原才能显示出来
img = tf.decode_raw(feat ['image_raw'], tf.uint8)
img = tf.reshape(img, [feat['height'],feat['width'],feat['depth']])

with tf.Session() as sess:  
    for i in range(1,3):
        image = sess.run(img)
        plt.figure()
        plt.imshow(image)

与文本文档创建数据集不同,每一个TFRecord文件都有自己不同的feature格式,因此在读取时需要提供一个parse函数来解析所读取的数据。

但有时我们不想讲文件路径写死,这时候我们就需要先用一个placeholder来初始化数据集,那么这时,对于这样的数据集,我们就不能再使用make_one_shot_iterator()来生成一个迭代器了。这时我们应该换成make_initializable_iterator()。但是在这种情况下我们需要注意一点:上一段程序中没有进行初始化是因为在使用make_one_shot_iterator()时,数据集所有的参数都已经确定了,但是使用make_initializable_iterator()时,我们则必须要在session中首先初始化iterator。

下面是使用make_initializable_iterator()的一个样例。

def parser(record):
    features = tf.parse_single_example(
            record,
            features={
                    'image_raw': tf.FixedLenFeature([], tf.string),
                    'height': tf.FixedLenFeature([], tf.int64),
                    'width': tf.FixedLenFeature([], tf.int64),
                    'depth': tf.FixedLenFeature([], tf.int64),
                    'label': tf.FixedLenFeature([], tf.int64)
                    })
    return features

# 这里给出的具体路径是一个placeholder,稍后再提供具体的路径
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)

dataset = dataset.map(parser)

# 由于需要placeholder来初始化数据集,所以这里用initializable_iterator
iterator = dataset.make_initializable_iterator()

# 返回下一个样本数据
feat = iterator.get_next()

# 对于图片而言,还得进一步还原才能显示出来
img = tf.decode_raw(feat ['image_raw'], tf.uint8)
img = tf.reshape(img, [feat['height'],feat['width'],feat['depth']])

with tf.Session() as sess:
    # 首先初始化iterator,并给出input_files的值
    sess.run(iterator.initializer,
             feed_dict={input_files:[r'D:\Anaconda3\code_hk\mobilenet_hk\1.tfrecords']}
             )
    while True:
       try:
           image = sess.run(img)
           plt.figure()
           plt.imshow(image)
       except tf.errors.OutOfRangeError:
           break

数据集的高层操作

在上面主要介绍了数据集的基础用法。在接下来的这一部分,将介绍数据集框架提供的一些方便实用的高层API。

map

map是在数据集上进行操作的最常用的方法之一。其实这个方法在上面的代码中已经使用了。在下面的代码中,map(parse)方法表示对数据集中的每一条数据调用指定的方法parse。对每一条数据进行处理之后,map将处理后的数据包装成一个新的数据集返回。这是一个非常高效的方法。

shuffle

shuffle的主要功能是将数据集打乱。其中shuffle有一个参数buffer_size。它的功能类似于队列框架下的tf.train.shuffle_batch中的min_after_dequeue参数所带来的的打乱的功能。shuffle算法在内部使用一个缓冲区来保存buffer_size条数据。缓冲区越大,随机打乱的性能就越好,但同样的占用的内存也越多。

batch

batch的主要功能就是通过参数batch_size指定输出的每个batch有多少条数据组成。它的功能类似于队列框架下的tf.train.batch的作用。

repeat

repeat是数据集框架中另一个比较常用的操作方法。这个方法可以将数据集中的书数据复制成多份,在训练过程中,每一份被称之为一个epoch。但是在使用这个操作方法的过程中有一点需要注意的是repeat是计算图上的一个计算节点,代表重复的操作过程,所以,如果在repeat之前使用过了shuffle等操作,repeat并不是将shuffle过后的数据集重复,而是连同shuffle操作重复。

因为这几个概念都比较简单,就都在同一段代码中做演示了。

# 列举输入文件。训练和测试使用不同的数据
train_files = tf.train.match_filenames_once('train_file-*')
test_files = tf.train.match_filenames_once('test_file-*')


# 定义parser方法从TFRecord中解析数据
# 这里是按照之前自己生成的一个图片数据集的TFRecord文件进行解析的
def parser(record):
    features = tf.parse_single_example(
            record,
            features={
                    'image_raw': tf.FixedLenFeature([], tf.string),
                    'height': tf.FixedLenFeature([], tf.int64),
                    'width': tf.FixedLenFeature([], tf.int64),
                    'depth': tf.FixedLenFeature([], tf.int64),
                    'label': tf.FixedLenFeature([], tf.int64)
                    })
    
    # 从原始图像数据解析除像素矩阵,并根据图像尺寸还原图像
    decoded_image = tf.decode_raw(features['image_raw'], tf.unit8)
    decoded_image = tf.reshape(decoded_image, 
                               [features['height'],features['width'],features['depth']])
    
    label = features['label']
    
    return decoded_image, label


image_size = 299  # 定义神经网络输入层图片的大小
batch_size = 100  # 定义组合数据batch的大小
buffer_size = 10000  # 定义随机打乱数据时buffer的大小

# 创建一个数据集
dataset = tf.data.TFRecordDataset(train_files)
dataset = dataset.map(parser)

# 对数据一次进行预处理、shuffle、batching操作。
# preprocess_for_train为7.2.2节中定义的一个图像预处理的函数。
# 因为上一个map得到的数据集中提供了decoded_image, label两个结果,
# 所以这个map需要提供有两个返回值的函数来处理数据。
# 在下面的代码中,lambda中的image代表的就是第一个map返回的decoded_image,
# label代表的是第一个map返回的label,这个lambda先调用preprocess_for_train
# 处理图片,然后再讲处理好的图片和label组成最终的输出
dataset = dataset.map(
        lambda image,label: (
                preprocess_for_train(image, image_size, image_size, None), label))

dataset = dataset.shuffle(buffer_size).batch(batch_size)

# 这里指定了整个数据集的重复次数,也就间接的确定了训练的轮数
num_epoch = 10
dataset = dataset.repeat(num_epoch)

# tf.train.match_filenames_once方法得到的文件列表也是存在不确定性的,
# 和placeholder的机制类似,所以这里也需要初始化,故用initializable_iterator
iterator = dataset.make_initializable_iterator()
# 经过上面的处理,这里得到的就是batch好的一批批数据了
image_batch, lanel_batch = iterator.get_next()

猜你喜欢

转载自blog.csdn.net/weixin_43923472/article/details/89676813