创建自己训练数据集---TFrecords实战

1. 前言

     刚开始学习TensorFlow的时候都是跑官方的例子:比如我们熟知的:MNIST手写体数据集或者cifar-10/100数据集,这些数据集都是官方以及整理好了的,直接使用就好。但是后面实际上不可能总是官方的数据集跑一跑就能学好TensorFlow的,所以一些项目需要处理自己的数据,整理自己的数据集。做过Kaggle竞赛的应该很熟悉.csv文件了,.csv文件非常方便,但是通常读取的时候,是一次性读取到内存里面的.要是内存小的话,就要想其他的办法了,那就变得很麻烦了.

     或者有时候,从硬盘上面直接读取图片啊什么的,因为图片的文件格式,存放位置各种各样等等一些因素,要是想在训练阶段直接这么使用的话,就更加麻烦了.所以,对于数据进行统一的管理是很有必要的.TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.

     小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.

     自己通过查找很多资料,完成了一个简单的自己的图片通过TFRecord做成数据集以在TensorFlow运行的例子,现将其汇总,以便日后以往还能查找到。

2. TFrecord实战代码

#_*_ coding:utf-8 _*_
'''
    做过kaggle竞赛的应该很熟悉.csv文件了,.csv文件非常方便,但是通常读取的时候,是一次性读取到内存里面的.
要是内存小的话,就要想其他的办法了,那就变得很麻烦了. 或者有时候,从硬盘上面直接读取图片啊什么的,因为图片的文件格式,
存放位置各种各样等等一些因素,要是想在训练阶段直接这么使用的话,就更加麻烦了.所以,对于数据进行统一的管理是很有必要的.
TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,
使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.
小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.

    本程序主要用于将自己的数据做成Tfrecords,以便tensorflow能够很好的进行划分batch和Tensor读取
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练
如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,
也就是从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,
这里引入一种比较高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords。

    TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。
我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,
并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

    从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。
这个操作可以将Example协议内存块(protocol buffer)解析为张量。
'''

#导入相关的包
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import time

EPOCH = 5
#设置batch_size的大小
BATCH_SIZE = 5
#设置出队后最小剩余量
min_after_dequeue = 10
#设置队列的容量
capacity = min_after_dequeue + 4 * BATCH_SIZE
OUTPUT_SIZE = 440
# OUTPUT_SIZE = 630
#图片通道数为3,代表彩色
DEPTH = 3
#定义各个路径
data_path = 'data/'
tfrecords_path = 'tfrecords/'
#需要生产的.tfrecords数目
tfrecords_num = 6

#用于获取项目所在绝对路径
#cwd = D:\软件安装\pycharm\PyCharm Community Edition 2017.3.3\workplace\TFrecods_practicing
#这里可以获取cwd,也可以不用获取,反正我使用的是相对路径(relative path)也一样能行
cwd = os.getcwd()

#把传入的value转化为整数型的属性,int64_list对应着 tf.train.Example 的定义
def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
#把传入的value转化为字符串型的属性,bytes_list对应着 tf.train.Example 的定义
def _byte_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

#生成TFRecords文件
def create_record(data_path,tfrecords_path,tfrecords_num):

    rows = 32
    cols = 32
    depth = DEPTH

    #第一个for循环确定要将自己的数据划分为多少个.tfrecords文件
    for i in range(tfrecords_num):
        # 先定义writer对象,writer负责将得到的记录写入TFRecords文件,此处有多个.tfrecords文件
        writer = tf.python_io.TFRecordWriter(tfrecords_path + str(i) + ".tfrecords")
        #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。这个列表以字母顺序。 它不包括 '.' 和'..' 即使它在文件夹中。
        #一张一张的写入TFRecords文件(这里有15张照片,准备划分为5个.tfrecords,每三张划分到一个.tfrecords文件)
        #当然,.tfrecords的数目和每次划分的数据的数目的乘积不一定满足总的数据的数目
        #img_name:[]
        for img_name in os.listdir(data_path)[i*BATCH_SIZE:(i+1)*BATCH_SIZE]:
            '''
            img_name:
                1.JPG
                10.JPG
                11.jpg
                12.JPG
                13.jpg
                14.jpg
                15.JPG
                2.JPG
                3.JPG
                4.JPG
                5.JPG
                6.JPG
                7.JPG
                8.JPG
                9.JPG
            '''
            # 打开图片
            img_path = data_path + img_name
            img = Image.open(img_path)

            #对图片做一些预处理操作
            img = img.resize((OUTPUT_SIZE, OUTPUT_SIZE))
            # 设置裁剪参数
            h, w = img.size[:2]
            j, k = (h - OUTPUT_SIZE) / 2, (w - OUTPUT_SIZE) / 2
            box = (j, k, j + OUTPUT_SIZE, k + OUTPUT_SIZE)
            # 裁剪图片
            img = img.crop(box=box)

            # 将图片转化为原生bytes
            img_raw = img.tobytes()
            #使用tf.train.Example来封装我们的数据
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'img_raw': _byte_feature(img_raw),
                'label': _int64_feature(i)
            }))
            #Example调用SerializeToString()方法将自己序列化并由
            #writer = tf.python_io.TFRecordWriter("train.tfrecords")对象保存,
            #最终是将所有的图片文件和label保存到同一个tfrecords文件中
            writer.write(example.SerializeToString())
    writer.close()
'''
基本的,一个Example中包含Features,Features里包含Feature(这里没s)的字典。
最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List
- tf.train.FloatList:列表每个元素为float。
- tf.train.Int64List:列表每个元素为int64。
- tf.train.BytesList:列表每个元素为string。
'''


'''
读取数据则以上过程的逆,先获取序列化数据,再解析:
一旦生成了TFRecords文件,接下来就可以使用队列(queue)读取数据了
TF多线程机制:
'''
def read_and_decode(filename):
    #读取tfrecords文件名到队列中,使用tf.train.string_input_producer函数,该函数可以接收一个文件名列表,
    #并自动返回一个对应的文件名队列filename_queue,之所以用队列是为了后续多线程考虑(队列和多线程经常搭配使用)
    filename_queue = tf.train.string_input_producer(filename)
    #实例化tf.TFRecordReader()类生成reader对象,接收filename_queue参数,并读取该队列中文件名对应的文件,
    reader = tf.TFRecordReader()
    #得到serialized_example(读到的就是.tfrecords序列化文件)
    _, serialized_example = reader.read(filename_queue)
    '''
    tf.parse_single_example函数,该函数能从serialized_example中解析出一条数据,
    这里tf.parse_single_example函数传入参数serialized_example和features,其中features是字典的形式,
    指定每个key的解析方式,比如image_raw使用tf.FixedLenFeature方法解析,这种解析方式返回一个Tensor,
    大多数解析方式也都是这种,另一种是tf.VarLenFeature方法,返回SparseTensor,用于处理稀疏数据,不再多提。
    这里还要注意必须告诉解析函数以何种数据类型解析,这必须与生成TFRecords文件时指定的数据类型一致。
    最后返回features是一个字典,里面存放了每一项的解析结果。
    最后只要读出features中的数据即可。比如,features['label'],features['pixels']。
    但要注意的是,此时的image_raw依然是字符串类型的,需要进一步还原成像素数组,
    用TF提供的函数tf.decode_raw来搞定:images = tf.decode_raw(features['image_raw'],tf.uint8)。

    '''
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           #这里解析图片的类型与生成TFrecords文件是指定images的类型得一致
                                           #所以解析的时候使用string类型,但是images解析出来是string,还需要
                                           #将其进一步还原为像素数组(无符号8位---uint8类型)
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    imgs = tf.decode_raw(features['img_raw'], tf.uint8)
    imgs = tf.reshape(imgs, [OUTPUT_SIZE, OUTPUT_SIZE, 3])
    #tf.cast(x, dtype, name=None) 对于传入的数据将其数据格式转换为dtype类型。
    imgs = tf.cast(imgs, tf.float32) * (1. / 255)
    labels = tf.cast(features['label'], tf.int32)

    '''
    上面已经得到了images和labels,这是一条数据,训练一次需要一个batch的数据,
    TF提供了tf.train.shuffle_batch函数,上述解析代码只要提供一次,
    然后将labels和images作为tf.train.shuffle_batch函数的参数,
    tf.train.shuffle_batch就能自动获取到一个batch的labels和images。
    tf.train.shuffle_batch函数获取batch的过程需要生成一个队列(加入计算图中),
    然后一个一个入队labels和images,然后出队组合batch.
    batch_size就是batch的大小,capacity指的是队列的容量,比如capacity设为1,而batch_szie为3,
    那么组成一个batch的过程中,出队的操作就会因为数据不足而频繁地被阻塞来等待入队加入数据,运行效率很低。
    相反,如果capacity被设置的很大,比如设为1000,而batch_size设置为3,那么入队操作在空闲时就会频繁入队,
    供过于求并非坏事,糟糕的是这样会占用很多内存资源,而且没有得到多少效率上的提升。
    还有一点值得注意,当使用tf.train.shuffle_batch时,为了使得shuffle效果好一点,出队后队列剩余元素必须得足够多,
    因为太少的话也没什么必要打乱了,因此tf.train.shuffle_batch函数要求提供min_after_dequeue参数来
    保证出队后队内元素足够多,这样队列就会等队内元素足够多时才会出队。
    显而易见,capacity必须大于min_after_dequeue。min_after_dequeue根据数据集大小和batch_size综合考虑,
    而capacity则通常设置为capacity= min_after_dequeue + 3*batch_size,在效率和资源占用之间取得平衡。
    '''

    #制作打乱顺序的batch
    img_batch, label_batch = tf.train.shuffle_batch([imgs, labels], batch_size=BATCH_SIZE,
                                                    capacity=capacity,
                                                    min_after_dequeue=min_after_dequeue)

    return img_batch,label_batch


#TFRecords_collection(模型汇总)
def TFRecords_collection():
    # 注意os.path.isfile只是判断传入的文件是否存在,而os.path.exists()则是判断文件或者文件夹是否存在
    # 这里判断.tfrecords文件是否已经写入成功
    if os.path.isfile(tfrecords_path + '0.tfrecords'):
        print('--------------.tfrecords文件已存在--------------')
    else:
        start_time = time.time()
        print('--------------开始制作tfrecords--------------')
        # 制作tfrecords
        create_record(data_path, tfrecords_path, tfrecords_num)
        print('------------制作结束:耗时%.2f seconds-----------' % (time.time() - start_time))

    # os.path模块主要用于文件的属性获取,os.path.join(path1[, path2[, ...]])将多个路径组合后返回,
    # 第一个绝对路径之前的参数将被忽略
    filenames = [os.path.join(tfrecords_path, '%d.tfrecords' % ii) for ii in range(tfrecords_num)]
    # 获取img_batch和label_batch
    img_batch, label_batch = read_and_decode(filenames)
    # 打印img_batch和label_batch的type,发现都是Tensor类型
    '''
    img_batch.type Tensor("shuffle_batch:0", shape=(5, 440, 440, 3), dtype=float32)
    label_batch.type Tensor("shuffle_batch:1", shape=(5,), dtype=int32)
    '''
    print('img_batch.type', img_batch)
    print('label_batch.type', label_batch)

    # 初始化所有的op
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # 启用队列,tf.train.start_queue_runners函数,这个函数中传入参数sess,就可以做到多线程训练
        threads = tf.train.start_queue_runners(sess=sess)

        for i in range(EPOCH):
            plt.figure(figsize=(15, 15))
            value, label = sess.run([img_batch, label_batch])
            print('第%d次打印标签:' % (i + 1), label)
            print('value.shape:', value.shape)
            print('value.type', type(value))
            for j in range(BATCH_SIZE):
                # 每个batch的照片数目:value.shape[0]
                plt.subplot(1, value.shape[0], (j + 1))
                plt.imshow(value[j:j + 1, :, :, :].reshape(OUTPUT_SIZE, OUTPUT_SIZE, 3))
            plt.show()
    #需要注意的是,如果需要对图片裁切大小的OUTPUT_SIZE进行更改的话,一定要将tfrecords文件夹的.tfrecords文件
    #全部删除,因为.tfrecoeds是上一次保存的,也就包含了上一次图片裁切大小的OUTPUT_SIZE的设定,
    #所以,如果不删除,则会导致读取了上一次的.tfrecords文件,而导致图片.reshape失败,从而程序报错
    #中间遇到一个错误就是这个,改了很久都没效果,所以要记住。


if __name__ == '__main__':
    TFRecords_collection()

3. 参考链接以及相关附注

     1. 项目的图片如下:
     
     2. 参考链接:

     1. Github:TF-learing/tfrecord/readme.md
     2. 简书:TensorFlow高效读取数据 | Ycszen-物语
     3. 十图详解tensorflow数据读取机制(附代码
     4. TensorFlow学习(十一):保存TFRecord文件
     5. TensorFlow教程:利用TensorFlow处理自己的数据
     6. TensorFlow动手玩:数据导入2

猜你喜欢

转载自www.cnblogs.com/Stoner/p/9051030.html
今日推荐