tensorflow读取文件数据

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_24548569/article/details/81747863

读取CSV文件数据

在tensorflow中读取CSV文件需要用到TextLineReaderdecode_csv

首先准备两个csv文件:
file1.csv 内容:

,2,3,4,11
1,,3,4,12
1,2,,4,13
1,2,3,,14

file2.csv 内容:

,2,3,4,21
1,,3,4,22
1,2,,4,23
1,2,3,,24

第5列(最后一列)作为样本的标签,1开头表示file1的样本,2开头表示file2的样本。csv文件中有空列。

读取csv文件代码:

import tensorflow as tf

# 生成文件名字符串张量列表,shuffle打乱文件名列表
filename_queue = tf.train.string_input_producer(["file1.csv", "file2.csv"], shuffle=True)

# 使用TextLineReader阅读器读取文件
reader = tf.TextLineReader()
# read方法每次读取一行数据,key表示数据所在的文件,value为读取的一行数据
key, value = reader.read(filename_queue)

# 设置默认数据,当读到空数据时使用默认数据
# 负号表示默认数据,数字表示列序号
record_defaults = [[-1], [-2], [-3], [-4], [0]]
# 解析csv数据
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
# 组成特征向量
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    # 将文件名填充到队列
    # 在调用 run 或 eval 执行 read 之前,必须调用 start_queue_runners 。否则 read 操作会被阻塞到文件名队列中有值为止。
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # 读取数据
    for i in range(50):
        key_run, value_run, example, label = sess.run([key, value, features, col5])
        print("step %02d file: %s raw_data: %s" % (i, key_run, value_run), end=" ")
        print("features: ", example, "label: ", label)

    coord.request_stop()
    coord.join(threads)

两个csv文件8个样本,这里读取50个样本,说明可以重复读取样本。

输出结果如下:

step 00 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 01 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 02 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 03 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 04 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 05 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 06 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 07 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 08 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 09 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 10 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 11 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 12 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 13 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 14 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 15 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 16 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 17 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 18 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 19 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 20 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 21 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 22 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 23 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 24 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 25 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 26 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 27 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 28 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 29 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 30 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 31 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 32 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 33 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 34 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 35 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 36 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 37 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 38 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 39 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 40 file: b'file2.csv:1' raw_data: b',2,3,4,21' features:  [-1  2  3  4] label:  21
step 41 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features:  [ 1 -2  3  4] label:  22
step 42 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features:  [ 1  2 -3  4] label:  23
step 43 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features:  [ 1  2  3 -4] label:  24
step 44 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 45 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12
step 46 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features:  [ 1  2 -3  4] label:  13
step 47 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features:  [ 1  2  3 -4] label:  14
step 48 file: b'file1.csv:1' raw_data: b',2,3,4,11' features:  [-1  2  3  4] label:  11
step 49 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features:  [ 1 -2  3  4] label:  12

读取二进制文件的固定长度的记录

从二进制文件中读取固定长度记录,需要 FixedLengthRecordReaderdecode_raw

这里使用的二进制文件是 CIFAR-10 数据集。文件格式是:每条记录的长度都是固定的,一个字节的标签,后面是3072( 3 × 32 × 32 )字节的图像数据。下载 CIFAR-10 数据集文件(CSDN不能设置免费,可以从 CIFAR-10 官网下载)

具体代码如下:

import tensorflow as tf
from PIL import Image

# 生成文件名字符串张量列表
data_dir = "cifar10_data/cifar-10-batches-bin"
filename_queue = [data_dir + ("/data_batch_%d.bin" % i) for i in range(1, 6)]
filename_queue = tf.train.string_input_producer(filename_queue)

label_bytes = 1
height = 32
width = 32
depth = 3
image_bytes = height * width * depth
# 每条记录长度
record_bytes = label_bytes + image_bytes

reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
# 把byte类型转成uint8类型
record = tf.decode_raw(value, tf.uint8)
# 提取标签,uint8类型转成int32类型
label = tf.cast(tf.strided_slice(record, [0], [label_bytes]), tf.int32)
# 提取图片数据,并 reshape
depth_major = tf.reshape(tf.strided_slice(record, [label_bytes], [record_bytes]), [depth, height, width])
# reshape 为 height × width × depth
uint8image = tf.transpose(depth_major, [1, 2, 0])

with tf.Session() as sess:
    # 将文件名填充到队列
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(10):
        key_run, label_run, image = sess.run([key, label, uint8image])
        print("step %02d file: %s label: %d" % (i, key_run, label_run))
        # 显示图片
        image = Image.fromarray(image)
        image.show()

    coord.request_stop()
    coord.join(threads)

输出结果:

扫描二维码关注公众号,回复: 3731459 查看本文章
step 00 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:0' label: 8
step 01 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:1' label: 5
step 02 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:2' label: 0
step 03 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:3' label: 6
step 04 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:4' label: 9
step 05 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:5' label: 2
step 06 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:6' label: 8
step 07 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:7' label: 3
step 08 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:8' label: 6
step 09 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:9' label: 2

参考tensorflow 中文社区 读取数据

猜你喜欢

转载自blog.csdn.net/qq_24548569/article/details/81747863