如何更快地训练Vision Transformer

点击上方“计算机视觉工坊”,选择“星标”

干货第一时间送达

ec4f7748200331fa0121e6118ced1d0f.png

作者丨zzk译

来源丨GiantPandaCV 

近期MetaAI发布了一篇博客,关于如何显著提升Vision Transformer的训练效率。

原文:[Significantly faster Vision Transformer training]

链接:https://ai.facebook.com/blog/significantly-faster-vision-transformer-training

What the research is

Vision Transformer模型几乎火遍计算机视觉各个领域,其性能随着参数增加和更久的训练过程而得到提升。随着模型越来越大,超过了万亿次浮点运算的规模,该领域达到了瓶颈:训练一个模型往往要一个月,需要几百上千个GPU,导致大部分人无法接触到大规模ViT模型,并进而增加了对加速器的需求。

为了降低门槛,让更多人能够应用ViT,我们开发一系列方法来加速整个训练。我们基于MetaAI的图像分类模型库PyCls实现了一系列优化,这些优化极大的提升了模型训练过程的吞吐量:

eb43c619a9a029c967e4c039f0c707bc.png

How it works ?

我们首先对代码库进行分析,以定位训练效率低下的原因,最后关注点落在计算类型上:大部分模型都是用FP32进行训练,如果使用FP16训练的话,可以降低显存占用,并提高模型训练速度,但这一做法经常会导致准确率下降

所以我们选了一个折中的方法:自动混合精度。在该方法下,我们用half类型进行计算,以加快训练,减少显存使用。并以fp32类型存储参数,以保证模型准确率。其中我们没有手动将网络各部分转换成half类型,而是应用AMP各种模式(如O1, O2, O3),以寻找一个既能提升速度又不影响精度的平衡点。

FSDP

为了让训练更加高效,我们应用了FSDP训练策略,他能够将参数,梯度,优化器状态分片到各GPU上。在FSDP的帮助下,我们可以用更少的GPU资源构建更大的模型。

FSDP策略可以参考 [数据并行Deep-dive: 从DP 到 Fully Sharded Data Parallel (FSDP)完全分片数据并行] 链接:https://zhuanlan.zhihu.com/p/485208899

MTA Optimizer

前向计算完毕后,优化器需要对各个参数进行修改。而当参数比较多的情况下,对应启动的Optimizer Kernel就会变得很多,通常这些Kernel都比较小,计算负担不大,启动Kernel的开销反而占了大头。

ContiguousParams中,它将模型参数放置到一块连续的显存中进行计算,这样就能减少优化器这部分的时间。下图是Resnet50+SGD是否应用ContiguousParams的比较,可以看到OptimizerStep这部分时间显著减少了。

0cd41421db147e6f6349a9d96ecad304.png

而NVIDIA的Apex库的做法则是在底层重新实现了一系列MultiTensorOptimizer,如Adam, Adagrad等等。

Apex这种方法比较硬核,普通用户如果想要自己自定义优化器并应用Multi Tensor的优化,就必须改动底层CUDA代码。而最近PyTorch也在计划提供了一系列foreach接口[Replace optimizers in torch.optim with the ones from torch.optim._multi_tensor] 链接:https://github.com/pytorch/pytorch/pull/49039,让用户只需要在Python层即可享受到优化,对应的MultiTensor版Momentum优化器代码如下所示:

torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, grads, alpha=1 - dampening)

Pooled Classifier

原版的ViT是额外加了一个分类token,来输出最后的分类结果。而这里采用平均池化 如:https://github.com/facebookresearch/pycls/blob/main/pycls/core/config.py#L205 处理最后的分类

Batch Second Input Tensor Layout

这里的数据格式与以往不同,将batch维度放在第二维,并在调用nn.MultiheadAttention的时候,设置batch_first=False,以减少不必要的转置

if self.batch_first and is_batched:
    return attn_output.transpose(1, 0), attn_output_weights
else:
    return attn_output, attn_output_weights

总感觉这个实现怪怪的

其他优化

我们在采取560大小的batchsize下,达到了1.51倍的加速比,进一步的我们将batchsize设置为384,并将图片大小增大到256,达到了1.86倍加速比。在全FP16运算下,能够达到2.18倍加速比,尽管这偶尔会降低准确率(在实验中,准确率降低不到10%)。

2edb272d5f0f50e93175a6b67804a8b2.png

使用上述优化,我们将Imagenet1K数据集每epoch训练时间从0.65小时降低到0.43小时

4ef617de9202d28343b3f76a7a4fda08.png

我们还研究了不同GPU配置对训练速度的影响,在不同配置下我们都实现了比DDP baseline更高的吞吐量。随着GPU增加,吞吐量会因为设备之间的通信开销略微下降。然而即使在64块GPU下,我们仍然比DDP基线快1.83倍

d042117a38c2e97534015004e941b7a5.png

文中链接

PyCls :https://github.com/facebookresearch/pycls

ContiguousParams:https://github.com/PhilJd/contiguous_pytorch_params

Adam:https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu

本文仅做学术分享,如有侵权,请联系删文。

干货下载与学习

后台回复:巴塞罗自治大学课件,即可下载国外大学沉淀数年3D Vison精品课件

后台回复:计算机视觉书籍,即可下载3D视觉领域经典书籍pdf

后台回复:3D视觉课程,即可学习3D视觉领域精品课程

计算机视觉工坊精品课程官网:3dcver.com

1.面向自动驾驶领域的多传感器数据融合技术

2.面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)
3.彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进
4.国内首个面向工业级实战的点云处理课程
5.激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解
6.彻底搞懂视觉-惯性SLAM:基于VINS-Fusion正式开课啦
7.彻底搞懂基于LOAM框架的3D激光SLAM: 源码剖析到算法优化
8.彻底剖析室内、室外激光SLAM关键算法原理、代码和实战(cartographer+LOAM +LIO-SAM)

9.从零搭建一套结构光3D重建系统[理论+源码+实践]

10.单目深度估计方法:算法梳理与代码实现

11.自动驾驶中的深度学习模型部署实战

12.相机模型与标定(单目+双目+鱼眼)

13.重磅!四旋翼飞行器:算法与实战

14.ROS2从入门到精通:理论与实战

15.国内首个3D缺陷检测教程:理论、源码与实战

重磅!计算机视觉工坊-学习交流群已成立

扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在交流顶会、顶刊、SCI、EI等写作与投稿事宜。

同时也可申请加入我们的细分方向交流群,目前主要有ORB-SLAM系列源码学习、3D视觉CV&深度学习SLAM三维重建点云后处理自动驾驶、CV入门、三维测量、VR/AR、3D人脸识别、医疗影像、缺陷检测、行人重识别、目标跟踪、视觉产品落地、视觉竞赛、车牌识别、硬件选型、深度估计、学术交流、求职交流等微信群,请扫描下面微信号加群,备注:”研究方向+学校/公司+昵称“,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进去相关微信群。原创投稿也请联系。

0ddfc52042602f04eff668c71f3c3c50.png

▲长按加微信群或投稿

d421f6833c0cee0bda2fc750897501e4.png

▲长按关注公众号

3D视觉从入门到精通知识星球:针对3D视觉领域的视频课程(三维重建系列三维点云系列结构光系列手眼标定相机标定、激光/视觉SLAM、自动驾驶等)、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:

学习3D视觉核心技术,扫描查看介绍,3天内无条件退款

70f573190d423a1b47b77baeb4934844.png

 圈里有高质量教程资料、答疑解惑、助你高效解决问题

觉得有用,麻烦给个赞和在看~

猜你喜欢

转载自blog.csdn.net/qq_29462849/article/details/124791760