Reading data in TensorFlow with tfrecords files and queues

Reprinted from: https://blog.csdn.net/liangjun_feng/article/details/79698809?spm=5176.9876270.0.0.39062ef1d8LNGc

TensorFlow is a mainstream deep learning framework. It is very convenient for beginners, who can type in code step by step following a tutorial, and the packaging is excellent. But frankly, at a certain stage this heavy encapsulation becomes maddening, especially when you try to write a small example of your own and find it hard to even get started. Official tutorials and books always work with preprocessed datasets such as MNIST and CIFAR-10, so you never see the low-level details, and beginners who are not yet familiar with the framework have no idea how a 3-D image dataset actually gets fed in. I struggled with this for days, combed through CSDN, Zhihu, Jianshu, GitHub, and Stack Overflow, stepped into countless pitfalls, and read nearly every available tutorial and bug fix. Below I summarize a reasonably clean template and share it, in the hope that those who come after me will not be tortured as long as I was.

TensorFlow's official documentation describes three ways to read data. Two of them require loading the images into memory first, which does not scale: as the amount of data grows, the memory footprint makes the program hard to run. The third method, introduced here, combines tfrecords files with queues and is not limited by memory. I will walk through it with a concrete dataset and program.

First, here are links to the data and the program; if you are not interested in the details, you can use them directly:
Dataset:  YaleB_dataset
Script:  Batch Generator.py

The general idea of using a queue to read image data breaks down into three steps:
1. Generate a txt list based on how the dataset is stored on disk; the list records each image's path and related information (label, size, etc.).
2. Following the list from step 1, read the data and metadata and write them, in a fixed format, into TensorFlow's special file format (.tfrecords).
3. Read batches of data from the .tfrecords file to feed the model.

Generation of data lists

The code for generating the data list depends on how the data is stored, so what follows is the process and program for my specific setup. My data is stored at /Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB, and its contents look like this:

[image: directory listing of the YaleB dataset]

Here "Class01" in the first filename denotes the first class, and "00000" means the first image in that class. The program that generates the list is as follows:

## Import the required libraries
import os
import numpy as np                # needed by load_file below
import cv2 as cv
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

def getTrianList():
    root_dir = "/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB"  # folder where the data is stored
    with open('/Users/zhuxiaoxiansheng/Desktop'+"/Yaledata.txt","w") as f:    # where the txt list is written
        for file in os.listdir(root_dir):
            if len(file) == 23:                     # image names are 23 characters long; this skips other files
                f.write(root_dir+'/'+file+" "+ file[11:13] +"\n")   # file[11:13] is the class number

The resulting list file looks like this:

[image: the generated Yaledata.txt list, one "path class" pair per line]
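To make the slicing in the list generator concrete: assuming a filename such as YaleB_Class01_00000.png (a hypothetical name that matches the 23-character pattern and the Class01/00000 scheme described above; the real filenames may differ), file[11:13] picks out exactly the class number that ends up as the label in each list line:

```python
# Hypothetical 23-character filename following the Class01/00000 scheme;
# the actual YaleB filenames may differ, but the slicing logic is the same.
file = "YaleB_Class01_00000.png"

print(len(file))      # 23 -- only names of this length are written to the list
print(file[11:13])    # 01 -- the class number used as the label
```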

Generate tfrecords file

Once you have the txt list file, you can proceed from it. First we need to generate the .tfrecords file; the code is as follows:

def load_file(example_list_file):   # read paths and class numbers from the list; the input is the list's path
    lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[('col1', 'S120'), ('col2', 'i8')])
    examples = []
    labels = []
    for example,label in lines:
        examples.append(example)
        labels.append(label)
    return np.asarray(examples),np.asarray(labels),len(lines)

def trans2tfRecord(trainFile,output_dir):    # generate the tfrecords file
    _examples,_labels,examples_num = load_file(trainFile)
    filename = output_dir + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(filename)
    for i,[example,label] in enumerate(zip(_examples,_labels)):
        example = example.decode("UTF-8")
        image = cv.imread(example)
        image = cv.resize(image,(192,168))    # note the (width, height) argument order; keep all images the same size
        image_raw = image.tostring()          # convert the image matrix to a byte string
        example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                }))
        writer.write(example.SerializeToString())
    writer.close()     # writing finished, close the writer
    return filename    # return the file path
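The tostring()/decode_raw round trip used above can be sanity-checked in plain NumPy (tostring is the older name for tobytes). One detail worth noting: cv.resize takes its size as (width, height), so cv.resize(image, (192, 168)) produces an array of shape (168, 192, 3), and the reader must reshape to that same layout:

```python
import numpy as np

# Stand-in for a resized image: cv.resize(image, (192, 168)) takes
# (width, height) and therefore yields shape (height=168, width=192, 3).
image = np.random.randint(0, 256, size=(168, 192, 3), dtype=np.uint8)

raw = image.tobytes()  # what .tostring() stores in the 'image_raw' feature

# Mirror of tf.decode_raw + tf.reshape on the reading side
restored = np.frombuffer(raw, dtype=np.uint8).reshape(168, 192, 3)

print(np.array_equal(image, restored))  # True
```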

The generated .tfrecords file is binary and not human-readable, so it is not shown here.
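The list parsing in load_file can also be checked without TensorFlow by pointing np.genfromtxt at a throwaway two-line list (the paths below are made up for the test):

```python
import os
import tempfile
import numpy as np

# A throwaway two-line list in the same "path<space>class" format
listfile = os.path.join(tempfile.mkdtemp(), "Yaledata.txt")
with open(listfile, "w") as f:
    f.write("/tmp/a.png 01\n/tmp/b.png 02\n")

# Same call and dtype as in load_file above
lines = np.genfromtxt(listfile, delimiter=" ",
                      dtype=[('col1', 'S120'), ('col2', 'i8')])
examples = [e.decode("UTF-8") for e, _ in lines]
labels = [int(l) for _, l in lines]
print(examples, labels)  # ['/tmp/a.png', '/tmp/b.png'] [1, 2]
```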

Read data from tfrecords file

The function that reads and decodes records from the tfrecords file is as follows:

def read_tfRecord(file_tfRecord):     # the input is the path of the .tfrecords file
    queue = tf.train.string_input_producer([file_tfRecord])
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(queue)
    features = tf.parse_single_example(
            serialized_example,
            features={
          'image_raw':tf.FixedLenFeature([], tf.string),
          'label':tf.FixedLenFeature([], tf.int64)
                    }
            )
    image = tf.decode_raw(features['image_raw'],tf.uint8)
    image = tf.reshape(image,[168,192,3])    # cv.resize(image,(192,168)) yields height 168, width 192
    image = tf.cast(image, tf.float32)
    image = tf.image.per_image_standardization(image)
    label = tf.cast(features['label'], tf.int64)    # this sets the format of the parsed record
    return image,label
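The per_image_standardization step normalizes each image to zero mean and unit variance; per TensorFlow's documentation the divisor is the standard deviation floored at 1/sqrt(N), so constant images do not divide by zero. A NumPy sketch of that formula:

```python
import numpy as np

def per_image_standardization(image):
    # NumPy sketch of tf.image.per_image_standardization: subtract the
    # per-image mean and divide by max(stddev, 1/sqrt(num_elements)).
    image = image.astype(np.float64)
    adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std

img = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
out = per_image_standardization(img)
print(np.isclose(out.mean(), 0.0), np.isclose(out.std(), 1.0))  # True True
```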

Test code

The above is the main code; here is the code to test whether it works:

if __name__ == '__main__':
    getTrianList()
    dataroad = "/Users/zhuxiaoxiansheng/Desktop/Yaledata.txt"
    outputdir = "/Users/zhuxiaoxiansheng/Desktop/Yaledata"

    trainroad = trans2tfRecord(dataroad,outputdir)
    traindata,trainlabel = read_tfRecord(trainroad)
    image_batch,label_batch = tf.train.shuffle_batch([traindata,trainlabel],
                                            batch_size=100,capacity=2000,min_after_dequeue = 1000) 

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())  
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord = coord)
        train_steps = 10  

        try:  
            while not coord.should_stop():  # returns True when the threads should stop
                example,label = sess.run([image_batch,label_batch])  
                print(example.shape,label)  

                train_steps -= 1  
                print(train_steps)  
                if train_steps <= 0:  
                    coord.request_stop()    # ask the threads to stop

        except tf.errors.OutOfRangeError:  
            print ('Done training -- epoch limit reached')  
        finally:  
            # When done, ask the threads to stop
            coord.request_stop()  
            # And wait for them to actually terminate
            coord.join(threads)

If it succeeds, you will see output like the following:

[image: console output printing the batch shapes and labels]

Although I fixed plenty of mistakes while writing this code, applying it to your own dataset will surely surface problems specific to your situation. Feel free to leave a message here when you run out of options, although I may not reply either.
