TensorFlow入门(十-III)tfrecord 图片数据 读写

版权声明:欢迎转载,但请务必注明原文出处及作者信息。 https://blog.csdn.net/Jerr__y/article/details/78594786

本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord

关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据

在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。

tfrecord_3_img_writer.py

# -*- coding:utf-8 -*- 

import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys
import os
import time

'''tfrecord 写入数据.
将图片数据写入 tfrecord 文件。以 MNIST png格式数据集为例。

首先将图片解压到 ../../MNIST_data/mnist_png/ 目录下。
解压以后会有 training 和 testing 两个数据集。在每个数据集下,有十个文件夹,分别存放了这10个类别的数据。
每个文件夹名为对应的类别编码。

现在网上关于打包图片的例子非常多,实现方式各式各样,效率也相差非常多。
选择合适的方式能够有效地节省时间和硬盘空间。
有几点需要注意:
1.打包 tfrecord 的时候,千万不要使用 Image.open() 或者 matplotlib.image.imread() 等方式读取。
 1张小于10kb的png图片,前者(Image.open) 打开后,生成的对象100+kb, 后者直接生成 numpy 数组,大概是原图片的几百倍大小。
 所以应该直接使用 tf.gfile.FastGFile() 方式读入图片。
2.从 tfrecord 中取数据的时候,再用 tf.image.decode_png() 对图片进行解码。
3.不要随便使用 tf.image.resize_image_with_crop_or_pad 等函数,可以直接使用 tf.reshape()。前者速度极慢。
4.如果有固态硬盘的话,图片数据一定要放在固态硬盘中进行读取,速度能高几十倍几十倍几十倍!生成的 tfrecord 文件就无所谓了,找个机械盘放着就行。
'''

# png 文件路径
TRAINING_DIR = '../../MNIST_data/mnist_png/training/'
TESTING_DIR = '../../MNIST_data/mnist_png/testing/'
# tfrecord 文件保存路径,这里只保存一个 tfrecord 文件
TRAINING_TFRECORD_NAME = 'training.tfrecord'
TESTING_TFRECORD_NAME = 'testing.tfrecord'

DICT_LABEL_TO_ID = {  # 把 label(文件名) 转为对应 id
    '0': 0,
    '1': 1,
    '2': 2,
    '3': 3,
    '4': 4,
    '5': 5,
    '6': 6,
    '7': 7,
    '8': 8,
    '9': 9,
}


def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))


def convert_tfrecord_dataset(dataset_dir, tfrecord_name, tfrecord_path='../data/'):
    """ convert samples to tfrecord dataset.
    Args:
        dataset_dir: 数据集的路径。
        tfrecord_name: 保存为 tfrecord 文件名
        tfrecord_path: 保存 tfrecord 文件的路径。
    """
    if not os.path.exists(dataset_dir):
        print(u'png文件路径错误,请检查是否已经解压png文件。')
        exit()
    if not os.path.exists(os.path.dirname(tfrecord_path)):
        os.makedirs(os.path.dirname(tfrecord_path))
    tfrecord_file = os.path.join(tfrecord_path, tfrecord_name)
    class_names = os.listdir(dataset_dir)
    n_class = len(class_names)
    print(u'一共有 %d 个类别' % n_class)
    with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
        for class_name in class_names:  # 对于每个类别
            class_dir = os.path.join(dataset_dir, class_name)  # 获取类别对应的文件夹路径
            file_names = os.listdir(class_dir)  # 在该文件夹下,获取所有图片文件名
            label_id = DICT_LABEL_TO_ID.get(class_name)  # 获取类别 id
            print(u'\n正在处理类别 %d 的数据' % label_id)
            time0 = time.time()
            n_sample = len(file_names)
            for i in range(n_sample):
                file_name = file_names[i]
                sys.stdout.write('\r>> Converting image %d/%d , %g s' % (
                    i + 1, n_sample, time.time() - time0))
                png_path = os.path.join(class_dir, file_name)  # 获取每个图片的路径
                # CNN inputs using
                img = tf.gfile.FastGFile(png_path, 'rb').read()  # 读入图片
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'image': bytes_feature(img),
                            'label': int64_feature(label_id)
                        }))
                serialized = example.SerializeToString()
                writer.write(serialized)
    print('\nFinished writing data to tfrecord files.')


if __name__ == '__main__':
    convert_tfrecord_dataset(TRAINING_DIR, TRAINING_TFRECORD_NAME)
    convert_tfrecord_dataset(TESTING_DIR, TESTING_TFRECORD_NAME)

tfrecord_3_img_reader.py

# -*- coding:utf-8 -*- 

import tensorflow as tf

'''read data
从 tfrecord 文件中读取数据,对应数据的格式为png / jpg 等图片数据。
'''

# **1.把所有的 tfrecord 文件名列表写入队列中
filename_queue = tf.train.string_input_producer(['../data/training.tfrecord'], num_epochs=None, shuffle=True)
# **2.创建一个读取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根据你写入的格式对应说明读取的格式
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'image': tf.FixedLenFeature([], tf.string),
                                       'label': tf.FixedLenFeature([], tf.int64)
                                   }
                                   )
img = features['image']
# 这里需要对图片进行解码
img = tf.image.decode_png(img, channels=3)  # 这里,也可以解码为 1 通道
img = tf.reshape(img, [28, 28, 3])  # 28*28*3

label = features['label']

print('img is', img)
print('label is', label)
# **4.通过 tf.train.shuffle_batch 或者 tf.train.batch 函数读取数据
"""
这里,你会发现每次取出来的数据都是一个类别的,除非你把 capacity 和 min_after_dequeue 设得很大,如
X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100,
                                          capacity=20000, min_after_dequeue=10000, num_threads=3)
这是因为在打包的时候都是一个类别一个类别的顺序打包的,所以每次填数据都是按照那个顺序填充进来。
只有当我们把队列容量舍得非常大,这样在队列中才会混杂各个类别的数据。但是这样非常不好,因为这样的话,
读取速度就会非常慢。所以解决方法是:
1.在写入数据的时候先进行数据 shuffle。
2.多存几个 tfrecord 文件,比如 64 个。
"""

X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100,
                                          capacity=200, min_after_dequeue=100, num_threads=3)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.启动队列进行数据读取
# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。
# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in xrange(5):
    _X_batch, _y_batch = sess.run([X_batch, y_batch])
    print('** batch %d' % i)
    print('_X_batch.shape:', _X_batch.shape)
    print('_y_batch:', _y_batch)
    y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最后记得把队列关掉
coord.request_stop()
coord.join(threads)

猜你喜欢

转载自blog.csdn.net/Jerr__y/article/details/78594786
今日推荐