import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt
# Build an int64 feature holding a single integer value.
def int64List(values):
    """Wrap one integer in a ``tf.train.Feature`` backed by an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))


# Build a bytes feature holding a single byte string.
def bytesList(values):
    """Wrap one byte string in a ``tf.train.Feature`` backed by a BytesList."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


# Load MNIST. dtype=tf.uint8 keeps the raw 0-255 pixel values; the default
# (float32) rescales to [0, 1] and stores 4 bytes per pixel, which would not
# match the tf.decode_raw(..., tf.uint8) used by the reader below.
mnist = input_data.read_data_sets('data/', dtype=tf.uint8)

# Write every training example to a TFRecord file.
num = mnist.train.num_examples
images = mnist.train.images
labels = mnist.train.labels
# The context manager closes (and flushes) the writer even if a write fails.
with tf.python_io.TFRecordWriter('path/mnist.tfrecord') as writer:
    for index in range(num):
        images_raw = images[index].tobytes()
        # Labels are numpy integer scalars; store them as a plain Python int.
        label = int(labels[index])
        example = tf.train.Example(features=tf.train.Features(feature={
            'labels': int64List(label),
            'images': bytesList(images_raw),
        }))
        writer.write(example.SerializeToString())

# Read the records back through a TF1 input queue.
reader = tf.TFRecordReader()
file_queue = tf.train.string_input_producer(['path/mnist.tfrecord'])
_, example_parse = reader.read(file_queue)
example = tf.parse_single_example(example_parse, features={
    'labels': tf.FixedLenFeature([], tf.int64),
    'images': tf.FixedLenFeature([], tf.string),
})
image_data = example['images']
label = example['labels']

# Decode: the image bytes were written as uint8 pixels.
image = tf.decode_raw(image_data, tf.uint8)
label = tf.cast(label, tf.float32)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(10):
            images, labels = sess.run([image, label])
    finally:
        # Stop and join the queue-runner threads before the session closes;
        # otherwise they keep running against a closed session.
        coord.request_stop()
        coord.join(threads)