TensorFlow 高性能数据输入管道设计指南

作者:黑暗星球
原文地址:https://blog.csdn.net/u014061630/article/details/80776975





TensorFlow版本:1.12.0

本篇主要介绍怎么使用 tf.data API 来构建高性能的输入 pipeline。

tf.data官方教程详见前面的博客<<<<<<<<<<tf.data官方教程


GPU、TPU的使用能够从根本上减少单个训练step所需的时间。但优异的性能不仅依赖于高速的计算硬件,也要求有一个高效的输入管道(Input Pipeline Performance Guide),这个管道在当前step完成前,进行下一个 step 需要的数据的准备。 tf.data API 对于灵活且高效的输入管道的建立非常有帮助。这个文档解释了 tf.data API 的特性,并介绍了构建高性能的 TensorFlow 数据输入管道的过程。

本文主要包含以下内容:

  • 介绍数据输入管道的结构(本质是一个 ETL 过程)。
  • tf.data 中,优化数据输入管道的常用方法。
  • 介绍了数据操作顺序对数据输入管道性能的影响。
  • 优异的数据输入管道应该具备的一些特质。

1. 数据输入管道的结构

TensorFlow数据输入管道可以被抽象为一个 ETL 过程(Extract,Transform,Load):

  • Extract:从硬盘上读取数据 ------ 可以是本地(HDD 或 SSD),也可以是网盘(GCS 或 HDFS)
  • Transform:使用 CPU 去解析、预处理数据 ------ 比如:图像解码、数据增强、变换(比如:随机裁剪、翻转、颜色变换)、打乱、batching。
  • Load:将 Transform 后的数据加载到 计算设备 ------ 例如:GPU、TPU 等设备。

上述的数据输入管道使用 CPU 来进行数据的 ETL 过程,从而让 GPU、TPU 等设备专心进行模型的训练过程(提高了设备的利用率)。另外,将数据输入管道抽象为 ETL 过程,有利于我们对数据输入管道进行优化。

当使用 tf.estimator.Estimator API 时,input_fn 需要完成 Extract 和 Transform 两个阶段。

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "img_encoded": tf.FixedLenFeature((), tf.string, ""),
    "img_label": tf.FixedLenFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["img_encoded"])
  return image, parsed["img_label"]

def input_fn(batch_size):
files = tf.data.Dataset.list_files("/path/to/dataset/train-.tfrecord")
ds = files.interleave(tf.data.TFRecordDataset, cycle_length=1)
ds = ds.shuffle(buffer_size=batch_size4)
ds = ds.map(map_func=parse_fn)
ds = ds.batch(batch_size=batch_size)
return ds

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

注:后面的过程以上面的 input_fn 为基础。

2. 数据输入管道的性能优化

新计算设备使得网络的训练越来越快,所以我们必须细心地设计 CPU 上运行的数据输入管道(防止其成为系统的性能瓶颈)。tf.data 提供了数据输入管道所需的各种部件,借助其,我们可以实现高效的数据输入管道,优化 ETL 过程的各个步骤。

2.1 数据准备(ET)和消耗过程(L)的解耦 ------ prefetch解耦、重叠两过程

在执行一个训练 step 之前,你必须 Extract、Transform 训练数据,然后将它馈送给计算设备。在以前,当 CPU 为计算准备数据时,计算设备处于闲置状态;当计算设备执行训练 step 时,CPU 处于闲置状态。因此,单个训练 step 的时间等于 CPU 准备数据的时间 + 计算设备执行训练 step 的时间。
这里写图片描述
Pipelining 将训练 step 中的 数据准备模型执行 “并行”。当计算设备在执行第 N 个训练 step 时,CPU 为第 N+1 个训练 step 准备数据。通过两个过程的重叠,单个训练 step 的时间等于 CPU 准备数据的时间 和 计算设备执行训练 step 的时间中较大值。
这里写图片描述
tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。prefech(n) 一般作为最后一个 transformation,其中 nbatch_size

prefetch 的使用方法如下:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size) # last transformation
return dataset

  
  
  • 1
  • 2
  • 3

注意:只要你的 数据产生过程 和 数据消耗过程 可以重合(无论重合多少),那么 prefetch 就能为你带来性能提升。

2.2 数据变换(T)的并行化 ------ 并行map,融合mapbatch

使用 tf.data.Dataset.map,我们可以很方便地对数据集中的各个元素进行预处理。因为输入元素之间时独立的,所以可以在多个 CPU 核心上并行地进行预处理。map 变换提供了一个 num_parallel_calls 参数去指定并行的级别。例如,下图为 num_parallel_calls=2map 变换的示意图:
这里写图片描述

num_parallel_calls 参数的最优值取决于你的硬件、训练数据的特质(比如:它的 size、shape)、map 函数的计算量 和 CPU 上同时进行的其它处理。比较简单的一个设置方法是:将 num_parallel_calls 设置为 CPU 的核心数。例如,CPU 有四个核心时,将 num_parallel_calls 设置为 4 将会很高效。相反,如果 num_parallel_calls 大于 CPU 的核心数,将导致低效的调度,导致输入管道的性能下降。

map 变换开启并行化的方法如下:

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)

  
  
  • 1

另外,如果你的 batch_size 比较大(成百上千),以 batch 的形式进行并行能够带来额外的性能提高。为此,tf.data 提供了 tf.contrib.data.map_and_batch 函数,其高效地融合了 mapbatch 两个变换。

为了融合 mapbatch 两个变换,我们只需要将:

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)
dataset = dataset.batch(batch_size=FLAGS.batch_size)

  
  
  • 1
  • 2

改为:

dataset = dataset.apply(tf.contrib.data.map_and_batch(
    map_func=parse_fn, batch_size=FLAGS.batch_size))

  
  
  • 1
  • 2

2.3 数据读取(E)的并行化------并行地读取并解析多个数据文件

在实际应用中,输入数据可能被存储在网盘(例如,GCS 或 HDFS)(要么因为输入数据不适合本地,要么因为训练是分布式的,在每台机器上复制输入数据是没有意义的)。另外,在本地能够很好的读取数据的数据输入管道也可能会卡在 I/O 瓶颈上,因为 本地 和 远程存储 有以下区别:

  • Time-to-first-byte(读取第一个bytes的时间):从远程存储读取文件的第一个字节的时间比本地存储长一个数量级。
  • Read throughput(读取吞吐量):虽然远程存储通常提供大的聚合带宽,但是读取单个文件可能仅能利用该带宽的一小部分。

另外,一旦原始字节被读取到内存中,也可能需要对数据进行反序列化或解密(例如:protobuf),这将导致额外的负载。不管数据是本地存储还是远程存储,该开销都存在,但如果数据未被高效地预加载,则远程情况下可能更糟。

为了减轻各种数据读取(E)开销的影响,tf.data 提供了 tf.contrib.data.parallel_interleave 函数。该函数可以并行地从多个文件中提取并解析数据。同时读取的文件数可以通过参数 cycle_length 来指定。

下图说明了将 parallel_interleave 中的 cycle_length=2 时的效果:
这里写图片描述
为了并行地读取数据(E),只需将:

dataset = files.interleave(tf.data.TFRecordDataset) # 现在该函数已经加了cycle_length参数

  
  
  • 1

改为:

dataset = files.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))

  
  
  • 1
  • 2

远程存储系统的吞吐量会受负载、网络事件等影响。为了缓解这种影响,可以将 parallel_interleaveprefetching 结合使用(详情见:tf.contrib.data.parallel_interleave

默认情况下,parallel_interleave 函数为元素提供一个确定性顺序,以方便再现。作为 prefetching 的一个替代方案(这在某些情况下,可能不高效),parallel_interleave 变换也提供了一个选项去提高性能(代价是元素的顺序的确定性)。尤其是,如果 sloppy 参数被设置为 True,变换可能偏离设定的顺序,通过临时跳过在下一个元素被请求时元素不可用的文件。

3. 数据输入管道性能的进一步优化

tf.data API 是围绕可组合的变换设计的(为用户提供灵活性)。虽然这些变换中的很多变换的次序是可交换的,但某些变换的次序对性能有影响。

3.1 Map and Batch ------ map开销很小时,batch形式的map更高效

将用户自定义的函数传给 map 函数 会产生调度、执行用户自定义函数的负载。一般情况下,这个负载与自定义函数的计算量相比很小。但是,如果 map 的函数的计算量很小,这个负载将是主要开销。在这种情况下,我们推荐使用向量化的自定义函数(它一次对一个batch进行变换),并且在 map 变换前使用 batch 变换。

3.2 Map and Cache ------ 通过cache进一步加速

tf.data.Dataset.cache 变化能够在内存或本地存储器上缓存一个数据集。如果传递给 map 变换的用户自定义函数的计算量很大,只要得到的数据集仍然适合内存或本地存储,就可以在 map 转换之后应用 cache 转换。

如果用户定义函数导致存储数据集需要的空间超过了 cache 的容量,考虑提前对数据集进行预处理,以减少资源的使用。

注意:cache将数据集进行缓存能够有效地提高数据输入管道的性能,但是,cache位置放置错误时,会导致模型性能下降。

3.3 Map and Interleave / Prefetch / Shuffle ------ 变换的顺序对内存使用量的影响

因为各个变换函数(包括 interleave,prefetch,shuffle)都有自己的内部缓存,所以如果传给 map 变换的 用户自定义函数 改变了元素的 size,那么 map 变换的次序影响内存的使用量。通常情况下,我们建议选择内存使用量更低的次序,除非不同的次序能够产生性能上的提高(例如,为了使用融合的 tf.contrib.data.map_and_batch)。

3.4 Repeat and Shuffle ------ repeat在前性能优,shuffle在前次序强

tf.data.Dataset.repeat 变换重复输入数据有限次(或无限次);数据的每一次重复称为一个 epoch。tf.data.Dataset.shuffle 变换随机打乱数据集 example 的次序。

如果 repeat 变换被放在 shuffle 变换之前,那么 epoch 边界将变得模糊。也就是说,某些元素可以在其他元素出现一次之前重复。另一方面,如果在 repeat 变换之前应用 shuffle 变换,那么在每个 epoch 开始时,性能可能会下降(因为这时,也需要进行 shuffle 变化的初始化)。换句话说,将 repeat 放置在 shuffle 之前,提供了更好的性能,将 shuffle 放置在 repeat 之前,提供了更强的次序保证。

当可能时,我们推荐使用融合op:tf.contrib.data.shuffle_and_repeat 变换,这个变换在性能和更强的次序保证上都是最好的(good performance and strong ordering guarantees)。否则,我们推荐在 repeat 之前使用 shuffle

4. 数据输入管道的最优实现

下面是设计最优数据输入管道的建议:

  • 使用 prefetch 函数去重叠 数据读取器 和 数据消耗器的工作。我们尤其推荐在输入管道的末端添加 prefetch(n) (n是batch size),以重叠 CPU 上的变换 及 GPU/TPU设备上的训练。详见【2.1】
  • 通过设置 num_parallel_calls 参数,来并行 map 变换。我们建议使用将该参数设置为 CPU 的核心数。详见【2.2】
  • 如果你使用 batch 变换来将预处理好的元素 batching,我们建议使用融合op:map_and_batch 变换;尤其是你如果使用大的batch size。详见【2.2】
  • 如果你的数据存在远程存储上,(且有时需要解析),我们建议使用 parallel_interleave 来并行数据的读取和解析。详见【2.3】
  • 将简单的用户自定义函数进行向量化,然后传递给 map 变换去分摊 用户自定义函数有关的调用、执行的负载。详见【3.1】
  • 如果你的数据能够加载到内存,使用 cache 变化去在训练的第一个 epoch 将数据集缓存到内存,所以能避免后来的 epoch 读取、解析、变换数据的负载。详见【3.2】
  • 如果你的预处理会增加你数据的 size,我们建议你首先使用 interleaveprefetchshuffle 变换去减少内存使用量(如果可能)。详见【3.3】
  • 我们建议在 repeat 变换之前使用 shuffle 变换,最好使用融合op: shuffle_and_repeat 变换。详见【3.4】

英文版本见:https://tensorflow.google.cn/performance/datasets_performance

        </div>
					<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-a47e74522c.css" rel="stylesheet">
            </div>
								
				<script>
					(function(){
						function setArticleH(btnReadmore,posi){
							var winH = $(window).height();
							var articleBox = $("div.article_content");
							var artH = articleBox.height();
							if(artH > winH*posi){
								articleBox.css({
									'height':winH*posi+'px',
									'overflow':'hidden'
								})
								btnReadmore.click(function(){
									if(typeof window.localStorage === "object" && typeof window.csdn.anonymousUserLimit === "object"){
										if(!window.csdn.anonymousUserLimit.judgment()){
											window.csdn.anonymousUserLimit.Jumplogin();
											return false;
										}else if(!currentUserName){
											window.csdn.anonymousUserLimit.updata();
										}
									}
									
									articleBox.removeAttr("style");
									$(this).parent().remove();
								})
							}else{
								btnReadmore.parent().remove();
							}
						}
						var btnReadmore = $("#btn-readmore");
						if(btnReadmore.length>0){
							if(currentUserName){
								setArticleH(btnReadmore,3);
							}else{
								setArticleH(btnReadmore,1.2);
							}
						}
					})()
				</script>




TensorFlow版本:1.12.0

猜你喜欢

转载自blog.csdn.net/Eartha1995/article/details/84888475
今日推荐