tensorflow 导入数据(2)

1、概述

在前一篇文章中详细讨论了迭代器与数据集的相关内容。由于数据集与迭代器是链接原始数据与程序连接的渠道,所以本文主要讨论如何从原始数据中构建数据集,主要涉及以下场景:

  • 内存
  • TFRecord data
  • 文本文件
  • csv文件

2、从内存中读取数据

如果所有的数据都以numpy数据组的形式预先保存到了内存当中,那么我们使用Dataset.from_tensor_slices()方法可以非常方便的将一个这样的数组转化为tensorflow的张量对象。下面以手写数字的数据集为例来说明这个场景的应用。

1)首先像我们在使用卷积神经网络处理手写数字一样,我们先来下载数据集,参考代码属下:

# mnist数据集
import tensorflow  as tf
mnist = tf.contrib.learn.datasets.load_dataset("mnist")

2)由于只是讨论数据的处理,这里我们使用数据量比较小的测试数据集

import numpy as np
eval_data = mnist.test.images 
print(eval_data.shape)
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
print(eval_labels.shape)

输出结果如下:

3)现在这个数据就是一个numpy array 的数据集,下面将其转化为一个Dataset,同时构造一个one-shot的迭代器来读取里边的数据

dataset = tf.data.Dataset.from_tensor_slices((eval_data, eval_labels))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

就目前tensorflow的机制来讲,所有保存在内存中的数据都需要转化为numpy数组才能够被更好的处理,如果是通过pandas读取的数据需要先转化为nd array才能够输入给tensorflow。

3、读取TFRecord 格式的数据

由于数据量非常大并不适用于直接读入到内存当中,只能从磁盘文件进行数据读入,tf.data支持各种格式的文件读入。TFRecord文件格式是一种简单的面向记录的二进制格式,许多TensorFlow应用程序使用它来训练数据。

1)关于TFRecord

tfrecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList、tf.train.Int64List、tf.train.FloatList写入tf.train.Feature,如下所示:

#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()])) 
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))  
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

我们可以使用TFRecordWriter类来实现TFRecord格式的生成,下面是生成的详细代码:

import numpy as np
eval_data = mnist.test.images 
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
writer = tf.python_io.TFRecordWriter('tfdata/img.tfrecords')
for i in range(len(eval_labels)):
    row_data = eval_data[i]
    labels = eval_labels[i]
    example = tf.train.Example(
           features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[row_data.tostring()]))
           }))
        # 将信息写入指定路径
    writer.write(example.SerializeToString())

下面可以使用dataset来读取数据,请参照如下代码:

def parse_data(data):
    feats = tf.parse_single_example(data, features={'img_raw':tf.FixedLenFeature([], tf.string),'label':tf.FixedLenFeature([],tf.int64)})
    image = tf.decode_raw(feats['img_raw'], tf.float32)
    label = feats['label']
    return image, label
dataset = tf.data.TFRecordDataset('tfdata/img.tfrecords')
dataset = dataset.map(parse_data)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    value = sess.run(next_element)
    print(value)

4、读取文本文件

很多数据集都是作为一个或多个文本文件分布的。tf.data.TextLineDataset 提供了一种从一个或多个文本文件中提取行的简单方法。给定一个或多个文件名,TextLineDataset 会为这些文件的每行生成一个字符串值元素。像 TFRecordDataset 一样,TextLineDataset 将 filenames 视为 tf.Tensor,因此您可以通过传递 tf.placeholder(tf.string) 来进行参数化,请参照如下代码:

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

猜你喜欢

转载自blog.csdn.net/amao1998/article/details/81260341