Tensorflow dataset

相关用法总是会忘,特此转载记下,同时在此向原创作者表示感谢。

Dataset是Tensorflow里面一个比较重要的概念,我们知道机器学习算法需要大概的数据来训练data model. 所以Dataset就是用来做这么一件重要的事情:定义数据pipline,为学习算法提供训练数据。

其实我们也可以将Dataset理解成一个数据源,指向某些包含训练数据的文件列表,或者是内存里面已有的数据结构(比如Tensor objects)。


Dataset 数据结构

组成Dataset的基本单元是element。每个element必需有相同的数据结构,其中每个element包含多个Tensor objects。比如:

# 创建一个dataset,里面包含一个2-Dimension (4x10) Tensor对象
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))

# 创建一个dataset,里面包含两个Tensor, tensor1的shape为(4x3), tensor2的shape为(4x5)
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4, 3]), tf.random_uniform([4, 5])))

创建Dataset

前面说了Dataset可以理解成数据源, 那么怎么创建一个Dataset并使它跟多个数据源关联呢?Tensorflow Dataset API提供了两种方式:

  1. 从已有的一个或者多个Tensors对象中创建
    上一节的Dataset.from_tensor_slices()就是这用这种方式创建的Dataset
    利用这种方式,同样地可以创建指向训练数据文件的Dataset,比如我们让每个element包含两个Tensor, 第一个Tensor指向一堆汽车的图片文件,另外一个vector tensor表示对应的图片是否为一辆卡车:

    train_imgs = tf.constant(['train/img1.png', 'train/img2.png',
                                                 'train/img3.png', 'train/img4.png',
                                                  'train/img5.png', 'train/img6.png'])
    train_labels = tf.constant([0, 0, 0, 1, 1, 1])
    tr_data = Dataset.from_tensor_slices((train_imgs, train_labels))

    这样dataset里面的每一个element其实就是一个tuple,包含了(feature, label)

  2. 对已有的Dataset进行转换(transformation),比如batch(), map(), filter(),后面会再介绍这些常用的API

    dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
    
    dataset2 = dataset1.batch(10)

    dataset2就是用这里介绍的第二种方法创建的。


读取Dataset

从前面Dataset的定义以及结构可以看出,Dataset其实是对Tensor提供了一层封装,而Tensor又是对真实的训练数据的封装,这些数据可能是一个N-Dimension matrix,或者是指向一批数据文件的向量。其实我们可以会问为什么要设计的这么复杂,又是matrix, 又是Tensor的,直接用Tensor/Matrix的API来读取训练数据不就行了么? 我觉得可以从下面几个方向来思考:

  • 在训练我们的model的时候,需要把训练数据input到我们的算法model中。但有时候训练数据不是说只有几百条,而是成千上万的,这样如果直接把这些数据load到内存中的Tensor肯定是吃不消的,所以需要一种数据结构让算法能够批量地从disk中分批读取,然后用它们来训练我们的model, Dataset正是提供这种机制(transformation)来满足这方面的需求。
  • 相比Tensor,Dataset对训练数据的读取更加灵活。当我们用常用的梯度下降算法来minimize我们的cost function时,需要不断地调整parameter的数值从而使cost不断地下降。这是一个迭代过程,每个迭代都需要读取不同batch size的训练数据来计算cost。Dataset提供了一些丰富的API可以读取不同batch size的数据。

回到正题,Dataset提供Iterator.get_next() API来读取它的每一个element,这个element包含一个或者多个我们需要的Tensor objects。

至于每次调用get_next()返回多少个element,则取决于batch size的大小。或者你可以认为batch size就是决定每次读取多少个训练数据,一个训练数据就是一个element。

Iterator的调用步骤:

  1. 定义一个Dataset

        dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 5]))
        #dataset = dataset.batch(2)
    
  2. 定义一个Iterator

        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()  
  3. 初始化Iterator (one shot iterator 除外), 如果有parameter需要初始化,将初始化的值传递给feed_dict

    sess.run(iterator.initializer, feed_dict={...})
  4. 用Iterator读取数据

    sess.run(next_element)   
    
    # output [ 0.58478916  0.3431859   0.23752177  0.19337153  0.05314612]
    

    如果将dataset的batch size定义成2,那么next element将会包含两个数组:

    sess.run(next_element)
    
    #output
    
    [[ 0.38093257  0.31324649  0.16414177  0.84969711  0.40212131]
     [ 0.18354928  0.55987918  0.09232235  0.98887277  0.21049285]]

这里需要特别提一下one shot iterator,它每次只读取一个element,而且 这种Iterator不需要初始化,也就是上面的第3步不需要显式地调用。但是只有当Dataset不包含任何参数时才可以为它创建one shot iterator, 前面例子里的Dataset都不能创建one shot iterator。
你可以这样来创建one shot iterator:

dataset2 = tf.data.Dataset.from_tensor_slices(tf.constant([[1, 2, 3], [2, 4, 6], [3, 6, 9]]))
iter2 = dataset2.make_one_shot_iterator()

用Dataset读取文件

前面的例子里很多的都是从Ternsor对象中创建Dataset, 所以用Iterator读取到的可能是一些常量数据,比如文件名,数组之类的。但是在真实的世界中,训练数据都是存放在文件中的,比如CSV,JPG,所以我们关心的其实并不是这些文件名本身,还是其中的内容。那么如果我的Tensor中存放的是一些文件名字,怎么用Dataset来读取其中的数据呢?

Dataset提供了一个数据预处理的API map()。 预处理的意思是可以对每一个element进行transformation,Iterator的get_next()拿到的可能是一个字符串代表某个文件名或者CSV文件里的一行,然后transformation的时候将这个文件的内容读取出来并保存在内存的Tensor对象。

读取文本文件

这里用TextLineDataset读取csv文件:

def readTextFile(filename):
    _CSV_COLUMN_DEFAULTS = [[1], [0], [''], [''], [''], [''],['']]
    _CSV_COLUMNS = [
    'age', 'workclass', 'education', 'education_num',
    'marital_status', 'occupation', 'income_bracket'
]

    dataset = tf.data.TextLineDataset(filename)
    iterator = dataset.make_one_shot_iterator()
    textline = iterator.get_next()

    with tf.Session() as sess:
        print(textline.eval())

    # convert text to list of tensors for each column
    def parseCSVLine(value):
        columns = tf.decode_csv(value, _CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        return features

    dataset2 = dataset.map(parseCSVLine)
    iterator2 = dataset2.make_one_shot_iterator()
    textline2 = iterator2.get_next()  

    with tf.Session() as sess:
        print(textline2)

这里parseCSVLine 将从csv读取到的每一行进行decode 处理(tf.decode_csv), 从而将每一列转成对应的Tensor object。

读取图片文件

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/west_609/article/details/78608541

猜你喜欢

转载自blog.csdn.net/tiankongtiankong01/article/details/80115873