# TensorFlow in practice: writing and reading TFRecord data (tensorflow实战之TFRecord数据的写入与读取)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
# Build an integer-valued feature.
def int64List(values):
    """Wrap integer value(s) in a ``tf.train.Feature`` holding an ``Int64List``.

    Args:
        values: a single integer (a Python ``int`` or a NumPy integer scalar,
            such as an element of the MNIST label array), or an iterable of
            integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the given value(s).

    Raises:
        TypeError: if any value is not an integer. Floats are rejected rather
            than silently truncated.
    """
    import operator

    try:
        items = list(values)
    except TypeError:
        # Not iterable (a plain int or a NumPy scalar): treat as one value.
        items = [values]
    # operator.index() turns NumPy integer scalars into plain ints, which some
    # protobuf releases require, while still rejecting non-integers.
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=[operator.index(v) for v in items]))
# Build a bytes-valued feature.
def bytesList(values):
    """Wrap byte-string value(s) in a ``tf.train.Feature`` holding a ``BytesList``.

    Args:
        values: a single ``bytes``/``bytearray``/``str`` value, or an iterable
            of such values. ``str`` items are UTF-8 encoded.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the given value(s).

    Raises:
        TypeError: if ``values`` is neither a string-like scalar nor iterable,
            or if protobuf rejects an element type.
    """
    # bytes/str are themselves iterable, so they must be special-cased as a
    # single value instead of being iterated element by element.
    if isinstance(values, (bytes, bytearray, str)):
        values = [values]
    # The BytesList proto field stores bytes only, so encode any text first.
    encoded = [v.encode('utf-8') if isinstance(v, str) else v for v in values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))
# Write the MNIST training set to a TFRecord file, one tf.train.Example per image.
# Load the data before opening the writer so a download/parse failure does not
# leave an open, empty record file behind.
mnist = input_data.read_data_sets('data/')
images = mnist.train.images
labels = mnist.train.labels
# TFRecordWriter fails when the target directory is missing; MakeDirs is a
# no-op if it already exists.
tf.gfile.MakeDirs('path')
# The context manager closes (and flushes) the writer even if a write raises.
with tf.python_io.TFRecordWriter('path/mnist.tfrecord') as writer:
    for image_row, label_value in zip(images, labels):
        # tobytes() stores only the raw buffer, not the dtype, so the reader
        # must decode with the same dtype as `images` (tostring() is the
        # deprecated alias of tobytes() and was removed in NumPy 2.0).
        example = tf.train.Example(features=tf.train.Features(feature={
            'labels': int64List(int(label_value)),
            'images': bytesList(image_row.tobytes()),
        }))
        writer.write(example.SerializeToString())
# Read the TFRecord file back through a TF1 queue-based input pipeline.
reader = tf.TFRecordReader()
file_queue = tf.train.string_input_producer(['path/mnist.tfrecord'])
_, example_parse = reader.read(file_queue)
example = tf.parse_single_example(example_parse,
                                  features={
                                      'labels': tf.FixedLenFeature([], tf.int64),
                                      'images': tf.FixedLenFeature([], tf.string)
                                  })
image_data = example['images']
label = example['labels']
# Decode the data. The image was stored as the raw buffer of the NumPy
# array (no dtype recorded), so the decode dtype must match what the writer
# serialized. NOTE(review): the writer calls read_data_sets() without dtype=,
# whose documented default is float32; decoding as uint8 would yield 4x as many
# meaningless bytes. Confirm if the writer's dtype ever changes.
image = tf.decode_raw(image_data, tf.float32)
label = tf.cast(label, tf.float32)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(10):
            images, labels = sess.run([image, label])
    finally:
        # Stop and join the queue-runner threads before the session closes;
        # otherwise they keep running against a closed session.
        coord.request_stop()
        coord.join(threads)

# Reposted from blog.csdn.net/hanyong4719/article/details/80375019