TFRecords数据

TFRecords是tensorflow官网提供的一种二进制文件,它能方便的进行数据复制 移动 和更好的利用内存,同时不需要单独的标签文件(在读取数据文件是自动添加标签,下面有介绍);在训练时,使用TFRecords中数据的流程:首先生成xxx.tfrecord文件,接着使用input pipeline读取xxx.tfrecords文件/其他支持格式,then读取并解码数据,随机乱序(shuffle),生成文件序列(batch);最后输入到模型中。

如果有一串jpg图片地址和相应的标签:Images和 Labels

1. 生成TFRecords

存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64;具体来说,TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example {
    Features features = 1;
};
message Features {
    map<string, Feature> feature = 1;
};
message Feature {
    oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
    }
};
# -*- coding: utf-8 -*-
import os 
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

cwd = "E:/Anaconda3/tensorflow/Dataset/data/"
classes = {'cats', 'dogs'} #预先自己定义的类别
#将数据转化TFRecord文件对应的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
# 开始将数据写入TFRecord文件(xxx.tfrecord)
train_filename = 'tensorflow/train.tfrecords'   # 输出文件地址
# 创建一个writer来写TFRecords文件(写TFRecords <==> 输出TFRecords文件)
writer = tf.python_io.TFRecordWriter(train_filename) #输出成tfrecord文件


for index, name in enumerate(classes):     # 从classes中自动获取类别 (label)
    class_path = cwd + name + '//'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name    #每张picture的绝对地址

        img = Image.open(img_path)
        img = img.resize((640, 320))
        img_raw = img.tobytes()  #将图片转化为二进制格式
        # 创建一个属性(feature)
        example = tf.train.Example(features = tf.train.Features(feature = {
                                                                                      "label":_int64_feature(index), 
                                                                           "img_raw":_bytes_feature(img_raw),                                                                          
                                                                           }))
        # 将上面的example protocol buffer 写入文件
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

输入: 数据文件路径 path

输出: xxx.tfrecords文件

2. 读取TFRecord 文件

(1). 用tf.train.string_input_producer 读取tfrecords文件(xxx.tfrecords)的list建立文件名队列(FIFO序列),同时,可以申明num_epoches和shuffle参数表示需要读取数据的次数以及时候将tfrecords文件读入顺序打乱;结果:图像路径list

(2). 定义TFRecordReader读取(1)中的序列(图像路径list)返回下一个record;结果:serialize example和feature字典

(3). 用tf.parse_string_example对读取的TFRecords文件进行解码,抽取((2) serialize example和feature字典)中,返回feature对应的值,此时对应的值都是string,需要经过tf.decode(...) 和 tf.cast(...)等操作,将string类型的图像数据还原原始图像;同时也可以进行一些preprocessing操作;

(4). 利用tf.train.shuffle_batch(...)和tf.train.batch(...)将(3)中还原原始图像生成batch图像序列

扫描二维码关注公众号,回复: 2927794 查看本文章
#读取文件
def read_and_decode(filename,batch_size):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })
 
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [300, 300, 3])                #图像归一化大小
   # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5   #图像减去均值处理,根据自己的需要决定要不要加上
    label = tf.cast(features['label'], tf.int32)        
 
    #特殊处理,去数据的batch,如果不要对数据做batch处理,也可以把下面这部分不放在函数里
 
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size= batch_size,
                                                    num_threads=64,
                                                    capacity=200,
                                                    min_after_dequeue=150)
    return img_batch, tf.reshape(label_batch,[batch_size])

在读取到队列中后,数据输出之前还要作解码的操作从,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量;

输入:XXX.tfrecords   batch_size

输出: image_batch  label_batch

3. 扩展

由于tf.train()函数在graph中增加了tf.train.QueueRunner类(在线程中运行线程中的队列数据),tf.train.start_queue_runner启动所有graph中的线程;用tf.train.Coordinator来管理线程(启动多少线程  何时终止线程...)

    # initialize global & local variables
    init_op =  tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
    sess.run(init_op)
    # create a coordinate and run queue runner objects
    # 启动多线程处理数据
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for batch_index in range(3):
        batch_images, batch_labels = sess.run([images, labels])
        for i in range(10):
            plt.imshow(batch_images[i, ...])
            plt.show()
            print "Current image label is: ", batch_lables[i]
    # close threads  结束线程
    coord.request_stop()
    coord.join(threads)
    sess.close()

4. 如何显示xxx.tfrecords文件中的图片

tfrecords_file = 'E:/Anaconda3/tensorflow//dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file,Batch_size)

with tf.Session()  as sess:

    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i<1:
            # just plot one batch size
            image, label = sess.run([image_batch, label_batch])
            for j in np.arange(4):
                print('label: %d' % label[j])
                plt.imshow(image[j,:,:,:])
                plt.show()
            i+=1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        coord.request_stop()
    coord.join(threads)

batch_size这里可以大家任意设定,显示几幅图片都可以,这里设置为6 同时i 控制显示张数

5. 完整代码

# -*- coding: utf-8 -*-
import os 
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

cwd = "E:/Anaconda3/tensorflow/dataset/data/"
classes = {'cats', 'dogs'}
writer = tf.python_io.TFRecordWriter('train.tfrecords')

def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

for index, name in enumerate(classes):
    class_path = cwd + name + '//'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name    #每张图片的绝对地址

        img = Image.open(img_path)
        img = img.resize((640, 320))
        img_raw = img.tobytes()  #将图片转化为二进制格式
        example = tf.train.Example(features = tf.train.Features(feature = {
                                                                           "label":_int64_feature(index),
                                                                           "img_raw":_bytes_feature(img_raw),                                                                          
                                                                           }))
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

def read_and_decode(filename, batch_size): # read train.tfrecords
    filename_queue = tf.train.string_input_producer([filename])# create a queue

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#return file_name and file
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' :tf.FixedLenFeature([],tf.string),
                                       })#return image and label

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [208, 208, 3])  #reshape image to 512*80*3

    label = tf.cast(features['label'], tf.int32) #throw label tensor

    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size= batch_size,
                                                    num_threads=64,
                                                    capacity=2000,
                                                    min_after_dequeue=1500,
                                                    )
    return img_batch, tf.reshape(label_batch,[batch_size])


tfrecords_file = 'D:/Anaconda3/tensorflow/dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file, Batch_size)

with tf.Session()  as sess:

    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i<1:
            # just plot one batch size
            image, label = sess.run([image_batch, label_batch])
            for j in np.arange(BATCH_SIZE):
                print('label: %d' % label[j])
                plt.imshow(image[j,:,:,:])
                plt.show()
            i+=1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        coord.request_stop()
    coord.join(threads)

       6. 参考文献

      1. https://blog.csdn.net/u012222949/article/details/72875281   有imageFile 和 labelFile, 将imageFile和 labelFile分成train_set  test_set   

        2. https://blog.csdn.net/wiinter_fdd/article/details/72835939      imageFile_train  + class{} 类别自动生成     + imageFile_test

        3. https://blog.csdn.net/gybheroin/article/details/79800679        同上

        4. http://www.cnblogs.com/arkenstone/p/7507261.html             结构特别清晰

        5. https://www.cnblogs.com/Charles-Wan/p/6197019.html         读取数据分类清晰

猜你喜欢

转载自blog.csdn.net/qinghange/article/details/82086127