TensorFlow 高效读取数据（TFRecord）

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 28 15:29:51 2018

"""
import tensorflow as tf
import numpy as np
from PIL import Image
import os
from scipy import misc
import time






# Build the dataset index from the flat 'ORL' directory under the current
# working directory. Each file name is expected to end in '_<label>.<ext>'
# (e.g. 'x_7.pgm' -> label 7).
temp_dir = os.getcwd()
# os.listdir returns entries in arbitrary order; sort them so the order in
# which images are written to the TFRecord file is deterministic.
files = sorted(os.listdir('ORL'))
files_name = [os.path.join(temp_dir, 'ORL', x) for x in files]
# Parse the label from the bare file name, not the full path: the full path
# may contain '_' or '.' from the working directory, which corrupts the
# parse for file names that have no '_' of their own.
labels = [int(x.split('_')[-1].split('.')[0]) for x in files]




def create_record(files_name, labels, name):
    """Serialize images and their integer labels into ``<name>.tfrecord``.

    One ``tf.train.Example`` is written per image. The image's raw pixel
    buffer is stored as a single bytes value under the key ``'imgs'``, and
    its label as a single int64 under ``'label'``.

    Args:
        files_name: sequence of image file paths.
        labels: sequence of integer labels, aligned index-by-index with
            ``files_name``.
        name: output path without extension; ``'.tfrecord'`` is appended.

    Raises:
        ValueError: if ``files_name`` and ``labels`` differ in length.

    Note:
        Only the raw bytes are stored, not the image shape or dtype. The
        reader must decode and reshape to exactly what was written here.
        For example, a grayscale image yields an (h, w) array, not
        (h, w, 3).
    """
    if len(files_name) != len(labels):
        raise ValueError(
            'files_name and labels must have the same length (got %d and %d)'
            % (len(files_name), len(labels)))

    # The context manager closes the writer even if reading or encoding an
    # image fails part-way through the loop.
    with tf.python_io.TFRecordWriter(name + '.tfrecord') as writer:
        for path, label in zip(files_name, labels):
            # scipy.misc.imread was removed in SciPy 1.2; read via PIL instead.
            # NOTE(review): for palette ('P') images this yields palette
            # indices, whereas misc.imread converted them to RGB; confirm the
            # dataset is grayscale/RGB.
            with Image.open(path) as pil_img:
                img = np.asarray(pil_img)
            img_raw = img.tobytes()  # raw pixel bytes; shape/dtype are not stored
            example = tf.train.Example(features=tf.train.Features(feature={
                'imgs': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])),
            }))
            writer.write(example.SerializeToString())
    
    
def read_record(filename, h, w, channels=None):
    """Build TF1 graph ops that read one (image, label) pair from a TFRecord file.

    This uses the queue-based input pipeline. A filename queue (not
    shuffled) feeds a ``TFRecordReader``. Each example is parsed for the
    ``'imgs'`` bytes and ``'label'`` int64 features, and the bytes are
    decoded as uint8 and reshaped. The returned tensors only produce values
    once queue runners have been started.

    Args:
        filename: path of a ``.tfrecord`` file whose examples contain an
            ``'imgs'`` bytes feature and a ``'label'`` int64 feature.
        h: image height.
        w: image width.
        channels: number of channels. With ``None`` (the default) the image
            is reshaped to ``[h, w]``; otherwise to ``[h, w, channels]``.

    Returns:
        A tuple ``(img, label)``: a uint8 tensor of shape ``[h, w]`` (or
        ``[h, w, channels]``) and an int32 scalar tensor.

    Note:
        The target shape must match the number of bytes stored per image.
        Otherwise the reshape fails inside a background reader thread. This
        can surface downstream as an ``OutOfRangeError`` from a batching
        queue that is closed and has insufficient elements.
    """
    filename_queue = tf.train.string_input_producer([filename], shuffle=False)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialized_example, features={
        'imgs': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    })

    img = tf.decode_raw(features['imgs'], tf.uint8)
    shape = [h, w] if channels is None else [h, w, channels]
    img = tf.reshape(img, shape)
    label = tf.cast(features['label'], tf.int32)

    return img, label




# Write the dataset once, then build the reading/batching graph on top of it.
create_record(files_name, labels, 'train')
# The reshape shape must match the stored images exactly (here 112 x 92,
# single channel); a mismatch kills the reader threads.
img, label = read_record('train.tfrecord', 112, 92)

img_batch, label_batch = tf.train.shuffle_batch(
    [img, label], batch_size=20, capacity=200, min_after_dequeue=100,
    num_threads=6)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()  # coordinates shutdown of the queue-runner threads
    # Start the QueueRunners; from this point the filename queue is being filled.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(200):
            # Use new names so the graph tensors `img`/`label` are not
            # rebound to numpy arrays.
            img_np, label_np = sess.run([img_batch, label_batch])
            print(img_np)
            print(label_np)
            print('----------------------------', i, '--------------------------')
            time.sleep(1)
    finally:
        # Always stop and join the reader threads, even if sess.run raises
        # (e.g. OutOfRangeError when the input queue is closed and empty).
        coord.request_stop()
        coord.join(threads)


    
    
    
    

运行时报错如下：

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)

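经检查发现，读入的图片形状是 (112, 92)（单通道），而代码中却将其 reshape 成了 (112, 92, 3)。两者的元素个数不一致，因此 reshape 在后台的队列读取线程中失败，该队列随即被关闭。之后 shuffle_batch 出队时发现队列已关闭且元素不足（current size 0），便报出了上面的 OutOfRangeError。也就是说，真正的错误是形状不匹配，OutOfRangeError 只是它的间接表现。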

将 reshape 的目标形状改为 (112, 92)（即代码中的 [h, w]）后，程序运行正常。

猜你喜欢

转载自blog.csdn.net/zlrai5895/article/details/80393919