Tensorflow Dataset.from_generator使用示例

之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介绍了使用Tensorflow Dataset进行数据导入的方法及其优势。最近在实际使用中越发感觉到这个方式非常好用,尤其是发现了.from_generator这个method。

关于Dataset.from_generator的简单介绍,请参见如下两个链接:

https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat

https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369

注意,Dataset.from_generator在旧版Tensorflow中没有,起码在1.3版本tf.contrib.data.Dataset中还没有,后来用的1.7版本就有了。

我们知道,tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。具体到data feeding上,也是如此。虽然设计了placeholder、train.slice_input_producer系列、Dataset等多种方式,但使用中仍有各种不便,尤其是在输入形式复杂、需要多重变换的时候更是如此。而Dataset.from_generator可以在一定程度上解决这个问题。

简单的说,Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:

# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834

"""
Expected outputs:

Batch No. 0:
[0 1 2 3]

Batch No. 1:
[4 0 1 2]

Batch No. 2:
[3 4 0 1]

Batch No. 3:
[2 3 4]

end!
"""

import numpy as np
import tensorflow as tf

def data_generator():
    dataset = np.array(range(5))
    for d in dataset:
        yield d

dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
dataset = dataset.repeat(3)
dataset = dataset.batch(4)

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()

with tf.Session() as sess:
    try:
        batch_num=0
        while True:
            one_batch = sess.run(one_element)
            print('Batch No. %d:' % batch_num)
            print(one_batch)
            print('')
            batch_num+=1

    except tf.errors.OutOfRangeError:
        print('end!')

很显然,这个的输出如下:

Batch No. 0:
[0 1 2 3]

Batch No. 1:
[4 0 1 2]

Batch No. 2:
[3 4 0 1]

Batch No. 3:
[2 3 4]

end!

下面给出一个复杂的问题。假设需要输入如下序列:

A B

A C B

C

其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。


针对这个问题,使用Dataset.from_generator的一个简化版示例如下:

# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834

"""
Expected outputs:

Batch No. 0:
[[ 1  2  3]
 [ 2  3 -1]]

Batch No. 1:
[[ 3 -1 -1]
 [ 4  5 -1]]

Batch No. 2:
[[ 6  7  8]
 [ 9 -1 -1]]

Batch No. 3:
[[10 11 12]
 [13 14 -1]]

Batch No. 4:
[[15 -1 -1]]

end!
"""

import io
import numpy as np
import tensorflow as tf

class DataFeeder:

    def __init__(self, filenames):
        self.filenames = filenames

    def file_readline(self):
        for filename in self.filenames:
            fr = io.open(filename, 'r', encoding='utf-8')

            while True:
                file_line = fr.readline()
                if not file_line:
                    break

                datalist = file_line.split()
                # if datalist is a list of filename, file contents can
                # be read and appendded here.
                yield np.asarray(datalist, dtype='int32')

            fr.close()

    def generate_batch(self, batch_size, num_epochs=None):
        dataset = tf.data.Dataset.from_generator(self.file_readline,
                                                 tf.int32,
                                                 tf.TensorShape([None]))

        dataset = dataset.repeat(num_epochs)
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes=tf.TensorShape([3]),
            padding_values=-1)

        iterator = dataset.make_one_shot_iterator()
        out_batch = iterator.get_next()

        return out_batch

filenames = ['a.txt', 'b.txt', 'c.txt']
data_feeder = DataFeeder(filenames)
one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)

with tf.Session() as sess:
    try:
        batch_num = 0
        while True:
            data_batch = sess.run(one_batch)
            print('Batch No. %d:' % batch_num)
            print(data_batch)
            print('')
            batch_num+=1

    except tf.errors.OutOfRangeError:
        print('end!')

其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:

a.txt:

1 2 3
2 3
3

b.txt:

4 5
6 7 8
9

c.txt:

10 11 12
13 14
15

运行以上代码的输出为:

Batch No. 0:
[[ 1  2  3]
 [ 2  3 -1]]

Batch No. 1:
[[ 3 -1 -1]
 [ 4  5 -1]]

Batch No. 2:
[[ 6  7  8]
 [ 9 -1 -1]]

Batch No. 3:
[[10 11 12]
 [13 14 -1]]

Batch No. 4:
[[15 -1 -1]]

end!

目前的输出,每个batch是batch_size * 3的矩阵。实际上,1~15的数字可以是某个图片的文件名,在file_readline()函数中读出这些数字后,可以继续读出这些文件的内容,并形成更高维度的Dataset输出,譬如:batch_size * img_size * img_size * img_channel的Dataset。


最后,说几点注意事项(详见代码):

1. generator函数不能有输入参数,但如果是class内的一个函数,可以使用self参数,这也是传递参数的一个手段;

2. 上述class中,建议传递文件名,在generator中打开处理再关闭,而不应该在外面打开(fr=open(filename, ‘r’)),然后把fr传递给generator读取。实践表明:后面这种方法形成的dataset不能repeat;

3. 因为序列不等长,在形成dataset batch时需要使用Dataset.padded_batch方法。


猜你喜欢

转载自blog.csdn.net/foreseerwang/article/details/80572182