TensorFlow 读取 TFRecord 格式数据并进行计算的代码

import tensorflow as tf
def _int64_feature(value):
  """Wrap a single integer in a ``tf.train.Feature``.

  Args:
    value: The integer to store; it is placed in a one-element Int64List.

  Returns:
    A ``tf.train.Feature`` whose ``int64_list`` contains only ``value``.
  """
  int64_list=tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

if __name__=="__main__":
    # Write ten records (index 0..9) to the first TFRecord file. Each record
    # carries two int64 features: v1 = index and v2 = index + 1.
    filename0="file0.tfrecords"
    # Using the writer as a context manager guarantees it is flushed and
    # closed, even if building or writing an example raises.
    with tf.python_io.TFRecordWriter(filename0) as writer:
        for index in range(10):
            example=tf.train.Example(features=tf.train.Features(feature={
                'v1':_int64_feature(index),
                'v2':_int64_feature(index+1)}))
            writer.write(example.SerializeToString())

# Write ten more records (index 10..19) to the second TFRecord file, using the
# same two int64 features as the first file: v1 = index and v2 = index + 1.
# The name must be "file1.tfrecords" because that is the path the input
# pipeline below lists in its filename queue.
filename1="file1.tfrecords"
# The context manager flushes and closes the writer even if a write fails.
with tf.python_io.TFRecordWriter(filename1) as writer:
  for index in range(10,20):
    example=tf.train.Example(features=tf.train.Features(feature={
      'v1': _int64_feature(index),
      'v2':_int64_feature(index+1)}))
    writer.write(example.SerializeToString())


# Queue of input files; shuffled, and each file is offered for two epochs.
input_files=["file0.tfrecords","file1.tfrecords"]
file_queue=tf.train.string_input_producer(input_files,shuffle=True,num_epochs=2)

# Pull one serialized tf.train.Example at a time from the queued files.
record_reader=tf.TFRecordReader()
_,raw_example=record_reader.read(file_queue)

# Both features are scalar int64 values.
feature_spec={
    'v1':tf.FixedLenFeature([],tf.int64),
    'v2':tf.FixedLenFeature([],tf.int64),
}
parsed=tf.parse_single_example(raw_example,features=feature_spec)

# Narrow to int32 and build the product op.
first=tf.cast(parsed['v1'],tf.int32)
second=tf.cast(parsed['v2'],tf.int32)
product=tf.multiply(first,second)
fetches=[first,second,product]

global_init=tf.global_variables_initializer()
# NOTE: per the TF docs, num_epochs makes string_input_producer create a local
# epoch counter, so local variables must be initialized as well.
local_init=tf.local_variables_initializer()

session=tf.Session()
for initializer in (global_init,local_init):
    session.run(initializer)

# Start the queue-runner threads that feed the filename queue.
coordinator=tf.train.Coordinator()
runner_threads=tf.train.start_queue_runners(sess=session,coord=coordinator)

try:
    # Each run() consumes one record; the loop ends when the epoch limit
    # makes the queue raise OutOfRangeError.
    while not coordinator.should_stop():
        row=session.run(fetches)
        print("%f\t%f\t%f"%tuple(row))
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # Always ask the runner threads to stop, whatever ended the loop.
    coordinator.request_stop()

# Wait for the threads to finish before releasing the session.
coordinator.join(runner_threads)
session.close()

猜你喜欢

转载自blog.csdn.net/liumoude6/article/details/82823595