Tensorflow两种数据读取方法应用、对比及注意事项

本文对比介绍了两种Tensorflow针对大数据集的数据读取方法,具体来说是:

方法1:tf.train.slice_input_producer+流水线(这里用这个API指代这一类,其实还有其他的API)。

方法2:Dataset方法,据说是Tensorflow 1.3版本之后引入的新API。使用起来比方法1更方便直观。


关于这两种方法的具体介绍,这里不再赘述,建议参考何之源的这两篇文章:

https://zhuanlan.zhihu.com/p/27238630

https://zhuanlan.zhihu.com/p/30751039


下面针对一组简单的示例数据集,分别用这两种方法实现数据读取,以进行对比学习。

数据集非常简单,就是下面这三张jpg图片文件名分别为:a.jpg, b.jpg, c.jpg。来源是何之源新书《21个项目玩转深度学习——基于TensorFlow的实践详解》第二章。



注意,这三张图片的分辨率不同,稍后会提到,这是使用中需要注意的一点。我们的目的是使用上述两种方法把这三张图分批读入。


方法1,代码如下:

# coding:utf-8
# blog.csdn.net/foreseerwang
# QQ: 50834

import tensorflow as tf
import numpy as np

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(image_decoded,
                  tf.convert_to_tensor([28, 28], dtype=tf.int32))
  return image_resized, label

filename = ['A.jpg', 'B.jpg', 'C.jpg']
labels = [1,2,3]

images_tensor = tf.convert_to_tensor(filename, dtype=tf.string)
labels_tensor = tf.convert_to_tensor([1,2,3], dtype=tf.int64)
file_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=1)

filenames = file_queue[0]
labels = file_queue[1]
image_resized, labels = _parse_function(filenames, labels)

value = tf.train.shuffle_batch([image_resized, labels], batch_size=2, capacity=5000,
                       min_after_dequeue=1000)

batNum = 0
with tf.Session() as sess:
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        while not coord.should_stop():
            image_data = sess.run(value)
            batNum += 1
            print('** Batch %d' % batNum)
            print(image_data[0].shape)
            print(image_data[1])
            print('\n')

    except tf.errors.OutOfRangeError:
        print('end!')

    finally:
        coord.request_stop()

    coord.join(threads)

输出:

** Batch 1
(2, 28, 28, 3)
[1 3]

end!

注意,这里有个问题:一共3张图片,要求每个batch有两张,因为3不能被2整除,最后剩下一张图片,没能输出。


方法2,代码如下:

# coding:utf-8
# blog.csdn.net/foreseerwang
# QQ: 50834

import tensorflow as tf

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(image_decoded,
                  tf.convert_to_tensor([28, 28], dtype=tf.int32))
  return image_resized, label

filename = ['A.jpg', 'B.jpg', 'C.jpg']
labels = [1,2,3]

filename_tensor = tf.convert_to_tensor(filename, dtype=tf.string)
labels_tensor = tf.convert_to_tensor(labels, dtype=tf.int64)

dataset = tf.contrib.data.Dataset.from_tensor_slices((filename_tensor, labels_tensor))
dataset = dataset.map(_parse_function)

dataset = dataset.shuffle(buffer_size=10).batch(2).repeat(1)

iterator = dataset.make_one_shot_iterator()
one_batch = iterator.get_next()

batNum = 0
with tf.Session() as sess:
    try:
        while True:
            databatch = sess.run(one_batch)
            batNum += 1
            print('** Batch %d' % batNum)
            print(databatch[0].shape)
            print(databatch[1])
            print('\n')
            
    except tf.errors.OutOfRangeError:
            print("end!")

输出(因为有shuffle,每次运行结果可能不同):

** Batch 1
(2, 28, 28, 3)
[1 3]

** Batch 2
(1, 28, 28, 3)
[2]

end!

这里就看到方法2的优势了,都是3张图片,按照每个batch 2张输出,Dataset方法就可以把所有数据都输出,最后一点残余都不剩下。


这些代码都比较简单,就不进一步解读了。需要说明的是:

1. 以上代码均在Tensorflow 1.3版本下运行通过。据说在1.4版本以后,Dataset API被提高了层级,变为:tf.data.Dataset,在使用中还请注意。

2. 上面_parse_function()子函数中的tf.image.resize_images操作是必需的,必须把所有图片修改为相同尺寸,否则在进行batch操作时会出错。当然,要求是每个batch里的数据维度必须相同,如果batch_size=1,那可以不进行图片resize。

3. _parse_function()子函数其实是一个通用模块,在实际应用中根据需要完全可能是其它更复杂的数据读取和处理过程。


通过以上代码可以看出,方法2 Dataset方法起码具有如下两个优势:

1. 不用进行复杂的流水线管理,没有coord/threads那些语句,应该是系统内自动管理了;

2. 可以把数据集完整的送出来,而不会留下因为数据条数无法对batch_size整除而剩下的尾巴数据(我不确定是否有什么方法可以解决这个问题,如有人知道,还请指教。谢谢!)。


猜你喜欢

转载自blog.csdn.net/foreseerwang/article/details/80170210