深度学习中基于tensorflow_slim进行复杂模型训练三之TFRecords的介绍

版权声明: https://blog.csdn.net/hh_2018/article/details/83539645

一、TFRecords的数据结构

TFRecords数据集是一种二进制的数据集,是tensorflow推荐的标准文件格式。Tensorflow通过ProtocolBuffers定义了TFRecords文件中存储的记录及其所含的字段结构,使用该方式可以将数据,标签以及和数据相关的信息通过key,value的形式存储在同一个文件中,并通过key,value的形式对存储的数据进行读取。该数据结构定义在tensrflow/core/example目录下的example.proto和feature.proto文件中,因此在构建实例时我们将转化后的张量称为样例,其内部记录称为特征域。

关于TFRecords的具体结构如下:

example = tf.train.Example(

features=tf.train.Features(

feature={

}))

其中example就是 样例,其中包含一个Features类型的数结构其命名为features,一个Features类型的数据结构又包含一个feature,feature中是多个key,value结构的数据,key是一个字符型数据,value是一个Feature型数据。

message Features {

map <string, Feature> feature = 1;

}

message Feature{

one of kind {

ByteList bytes_list = 1;

FloatList float_list = 2;

Int64List int64_list = 3 ;

}

}

二、TFRecords的写入方式(读入图片数据为例,存储对应的图片内容,宽,高,标签)

写入TFRecords数据时需主要分三步:

1. 定义对应的数据结构:

def image_to_tfexample(image_data, image_format, height, width, class_id):

return tf.train.Example(

features=tf.train.Features(

feature={

'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),

'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),

'image/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=class_id)),

'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),

'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),

}))

2. 定义数据的读取方式:该步骤主要是将原始的数据读入并转化成相应的Bytes或者Int64的格式放入到对应的结构中,本例中读入的是图片,所以读取的方式如下:

for i in range(len(path)):

image_data = tf.gfile.FastGFile(path[i], 'rb').read()

image = tf.image.decode_jpeg(image_data)

image = sess.run(image)

print(image.shape)

height = image.shape[0]

width = image.shape[1]

class_id = [i]

3. 定义TFRecords生成的名字、写入和关闭文件:在该部分首先要在循环外面使用

writer = tf.python_io.TFRecordWriter('image_test.tfrecord')定义一个输出的文件名。

接着在循环中使用下面的语句将数据写入:

example = image_to_tfexample(image_data, b'jpg', height, width, class_id)

writer.write(example.SerializeToString())

最后使用writer.close()关闭文件。

通过上述运行之后会在目录中出现一个名为image_test.tfrecord的二进制文件,该文件中存储了所有相关的信息。

三、 TFRecords数据的读取。

关于该数据在读取时主要分为两部分:

1. 构造读取结构:该结构是<key, values>形式的数据,并且key的名字要和生成该文件时的名字相同,values的类型要和生成时对应的values类型一致,其中FixedLenFeature的输入为特征形状和特征数据的类型。

keys_to_features = {

'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),

'image/format': tf.FixedLenFeature([], tf.string, default_value=''),

'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),

'image/height': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),

'image/width': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),

}

2. 构建读取器和解析器(将样例转换为张量)

filename_queue = tf.train.string_input_producer(['image_test.tfrecord'], num_epochs=2)

reader = tf.TFRecordReader()

_, serialized_example = reader.read(filename_queue)

# 将一条序列化的样例转化为其包含所有特征的向量

features = tf.parse_single_example(

serialized_example,

features=keys_to_features

)

其中TFRecordReader()表示用来读取tfrecord格式的数据,其对应的read方法要求传入的是一个queue。得到对应的样例, tf.parse_single_example的作用是将输入的一个样例按照传入的features字典的形式转化成对应的张量。

至此就完成了对TFRecords数据的存储和读取,由于读取出的数据是tensor,因此要使用sess.run()的方式对数据进行显示,同样也可以对读取的数据进行和正常的数据同样的操作。

四、读取数据的操作

example = sess.run(features)

image1 = tf.image.decode_jpeg(example['image/encoded'], channels=3)

print(image1.shape)

此时image1的形状为[ ?, ?,3]

image2 = tf.image.resize_images(image1, [160, 160], method=1)

print(image2.shape)

此时image2的形状是 [160, 160, 3]

height = example['image/height']

print(height)

image = sess.run(image1)

print(image.shape)

此时image1的形状是 [250, 196, 3] 原始的图像大小。

# image = sess.run(image2)

# print(image2.shape)

print(example)

path = "image1/" + str(height) + str(i) + ".jpg"

misc.imsave(path, image)

对图片进行存储。

我们运行eaxmple时可以发现其结果是我们定义的<key, values>结构数据,通过key可以读出其对应的value,但是对于image/encoded而言,其对应的值是二进制数据,因此需要是使用decode_jpeg的方式将其进行解码,但是解码后的图片结构只是一个tensor,此时没有具体的大小,因此我们可以人为的确定其大小,如果不想改变其大小,直接运run()可以按照原有的大小恢复图片。

由于read方法读入的是一个队列,因此关于如何使用队列和线程进行数据的读取可以参考https://blog.csdn.net/hh_2018/article/details/81143109

猜你喜欢

转载自blog.csdn.net/hh_2018/article/details/83539645