【Tensorflow】你可能无法回避的 TFRecord 文件格式详细讲解

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/briblue/article/details/80789608

如果你是 Tensorflow 的初学者,那么你或多或少在网络上别人的博客上见到过 TFRecord 的影子,但很多作者都没有很仔细地对它进行说明,这也许会让你感受到了苦恼。本文按照我自己的思路对此进行一番讲解,也许能够提供给你一些帮助。

TFRecord 是什么?

TFRecord 是谷歌推荐的一种二进制文件格式,理论上它可以保存任何格式的信息。

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

上面是 Tensorflow 的官网给出的文档结构。整个文件由文件长度信息、长度校验码、数据、数据校验码组成。

但对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord 文件。

TFRecord 的核心内容在于内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。

在这里我相信大家都对 protocolbuf 比较了解,如果不了解也没有关系,它本质上和 xml 及 json 没有多大的区别。

网上有很多 example 的简单说明。

message Example {
  Features features = 1;
};

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

熟悉 protobuf 同学看到这个格式定义就能马上明白了,不熟悉的同学可以点击相关的文章,我之前的这篇有对 protocolbuf 作过详细解释。

一个 Example 消息体包含了一系列的 feature 属性。

每一个 feature 是一个 map,也就是 key-value 的键值对。

key 取值是 String 类型。

而 value 是 Feature 类型的消息体,它的取值有 3 种:

  1. BytesList
  2. FloatList
  3. Int64List

需要注意的是,他们都是列表的形式。

protocolbuf 是通用的协议格式,对主流的编程语言都适用。所以这些 List 对应到 python 语言当中是 列表,而对于 Java 或者 C/C++ 来说他们就是数组。

举个例子,一个 BytesList 可以存储 Byte 数组,因此像字符串、图片、视频等等都可以容纳进去。

所以 TFRecord 可以存储几乎任何格式的信息。

但需要说明的是,更官方的文档来源于 Tensorflow的源码,这里面有详细的定义及注释说明。

为什么要用 TFRecord ?

TFRecord 也不是非用不可,但它确实是谷歌官方推荐的文件格式。

1、它特别适应于 Tensorflow ,或者说它就是为 Tensorflow 量身打造的。
2、因为 Tensorflow开发者众多,统一训练时数据的文件格式是一件很有意义的事情。也有助于降低学习成本和迁移成本。

TFRecord 怎么用?

TFRecord 是一种文件格式,那么对于 TFRecord 文件的 IO 怎么处理呢?

事实上,Tensorflow 给我们提供了丰富的 API ,开发者运用这些 API 可以轻松地处理 TFRecord 文件。

创建 TFRecord 文件

我们可以利用 TFWriter 轻松完成这个任务。

但制作之前,我们要先明确自己的目的。

我们必须想清楚,要把什么信息存储到 TFRecord 文件当中,这其实是最重要的。

下面,举例说明。

因为深度学习很多都是与图片集打交道,那么,我们可以尝试下把一张张的图片转换成 TFRecord 文件。

首先定义 Example 消息体。

Example Message {
    Features{
        feature{
            key:"name"
            value:{
                bytes_list:{
                    value:"cat"
                }
            }
        }
        feature{
            key:"shape"
            value:{
                int64_list:{
                    value:689
                    value:720
                    value:3
                }
            }
        }
        feature{
            key:"data"
            value:{
                bytes_list:{
                    value:0xbe
                    value:0xb2
                    ...
                    value:0x3
                }
            }
        }
    }

}

上面的 Example 表示,要将一张 cat 图片信息写进 TFRecord 当中,而图片信息包含了图片的名字,图片的维度信息还有图片的数据,分别对应了 name、shape、content 3 个 feature。

这里写图片描述

下面,我们开始用代码实现它。

def write_test(input,output):
    ''' 借助于 TFRecordWriter 才能将信息写进 TFRecord 文件'''
    writer = tf.python_io.TFRecordWriter(output)

    # 读取图片并进行解码
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)

    with tf.Session() as sess:
        image = sess.run(image)
        shape = image.shape
        # 将图片转换成 string。
        image_data = image.tostring()
        print(type(image))
        print(len(image_data))
        name = bytes("cat", encoding='utf8')
        print(type(name))

        # 创建 Example 对象,并且将 Feature 一一对应填充进去。
        example = tf.train.Example(features=tf.train.Features(feature={
            'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
            'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
        }
        ))

        # 将 example 序列化成 string 类型,然后写入。
        writer.write(example.SerializeToString())

        writer.close()

write_test('cat.jpg','cat.tfrecord')

运行上面的代码,就可以在当前目录生成 cat.tfrecord 文件。

上面代码注释都比较详细,我挑重点来讲。

  1. 将图片解码,然后转化成 string 数据,然后填充进去。
  2. Feature 的 value 是列表,所以要记得加 []
  3. example 需要调用 SerializetoString() 进行序列化后才行。

TFRecord 文件的读取

上一节是讲如何将一张图片的信息写入到一个 tfrecord 文件当中。

现在,我们需要检验它是否正确,这就需要用到如何读取 TFRecord 文件的知识点了。

def _parse_record(example_proto):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)}
    parsed_features = tf.parse_single_example(example_proto, features=features)
    return parsed_features

def read_test(input_file):

    # 用 dataset 读取 tfrecord 文件
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_one_shot_iterator()


    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name']
        name = name.decode()
        img_data = features['data']
        shape = features['shape']
        print('=======')
        print(type(shape))
        print(len(img_data))

        # 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
        img_data = np.fromstring(img_data,dtype=np.uint8)
        image_data = np.reshape(img_data,shape)


        plt.figure()
        #显示图片
        plt.imshow(image_data)
        plt.show()

        #将数据重新编码成 jpg 图片并保存
        img = tf.image.encode_jpeg(image_data)
        tf.gfile.GFile('cat_encode.jpg','wb').write(img.eval())

read_test('cat.tfrecord')

代码比较简单,我也有给详细的注释,我挑重要的几点讲解一下。

  1. 我用 dataset 去读取 tfrecord 文件
  2. 在解析 example 的时候,用现成的 API 就好了 tf.parse_single_example
  3. 用 np.fromstring() 方法就可以获取解析后的 string 数据,记得数据格式还原成 np.uint8
  4. 用 tf.image.encode_jpeg() 方法可以将图片数据编码成 jpeg 格式。
  5. 用 tf.gfile.GFile 对象可以将图片数据保存到本地。
  6. 因为将图片 shape 写进了 example 中,解析的时候必须制定维度,在这里是 [3] ,不然程序报错。

运行程序后,可以看到图片显示正常.

这里写图片描述

并且将 TFRecord 中的图片数据也成功地保存到本地了。

一些疑问

Q:我的示例为什么用 Dataset 而不用大多数博文中的 QueueRunner 呢?

A:这是因为 Dataset 比 QueueRunner 新,而且是官方推荐的,Dataset 比较简单。

Q:学习了 TFRecord 相关知识,下一步学习什么?
A:可以尝试将常见的数据集如 MNIST 和 CIFAR-10 转换成 TFRecord 格式。

下一篇博文,我就将怎么将 CIFAR-10 转换成 TFRecord 格式人数据集,然后构建简单的神经网络去实验它。

猜你喜欢

转载自blog.csdn.net/briblue/article/details/80789608