Best practice of ResNet high-accuracy pre-trained model in MMDetection

1 Introduction

As the most common backbone network, ResNet plays a crucial role in object detection algorithms. Many classic target detection algorithms, such as RetinaNet, Faster R-CNN and Mask R-CNN, all use ResNet as the backbone network and are optimized on this basis. At the same time, most of the subsequent improved algorithms will use RetinaNet, Faster R-CNN and Mask R-CNN as the baseline for fair comparison.

Recently, both TIMM  and  TorchVision  have announced the latest training techniques to improve the performance of ResNet. The scheme is called  ResNet Strikes Back (rsb) in TIMM, which improves the top1 accuracy of ResNet50 from 76.1 to 80.4 on the ImageNet 1k dataset, while TorchVision calls it TorchVision New Recipes (tnr), which makes top1 accurate The rate has increased to 80.86, both of which have increased considerably.

With such a strong pre-trained ResNet backbone, will applying it to downstream object detection tasks bring huge improvements? This is a very worthwhile question. To this end, the MMDetection team has provided a good answer to this question through extensive experiments and parameter tuning. Taking Faster R-CNN as an example, the performance table on the COCO Val dataset is as follows:

No. 1 is the Faster R-CNN baseline. It can be seen that based on the high-precision pre-trained ResNet model r50-mmcls, after the optimization of the optimizer, learning rate and weight attenuation coefficient, the mAP performance on Faster R-CNN can be improved by up to 3.4 ( r50-mmcls refers to the pre-trained model trained on MMClassification using the rsb strategy). At the same time, we searched a set of optimal parameters for each backbone, which is convenient for users to refer to.

2 Comparison of training strategies between rsb and tnr on ResNet50

This article will first carefully analyze and explain the training strategies of rsb and tnr, and then describe how to fine-tune in downstream object detection tasks to greatly improve the performance of classical detection models.

2.1 Summary table

First of all, for the convenience of viewing and comparison, we have sorted out the following comparison table:

  • ResNet50-base refers to the ResNet50 baseline result
  • ResNet50-rsb refers to the training results of the ResNet Strikes Back strategy proposed by TIMM, specifically the A1 strategy
  • ResNet50-tnr refers to the training results of the New Recipe strategy proposed by TorchVision
  • ResNet50-Deit-S refers to the result of training ResNet based on the Deit-S algorithm strategy used in TIMM. This experiment is for a fair comparison between DeiT-S and ResNet Strikes Back

2.2 Details of ResNet baseline training techniques

ResNet baseline is the ResNet50-base column of the above table. Note that there are two versions of ResNet for historical reasons: ResNet-PyTorch and ResNet-Caffe, the difference is the Bottleneck module, Bottleneck is a 1x1-3x3-1x1 stack structure, in caffe mode mode stride=2 parameter is placed in the first 1x1 volume product, while stride=2 in pyorch mode is placed at the second 3x convolution. A simple example is as follows:

if self.style == 'pytorch': 
      self.conv1_stride = 1 
      self.conv2_stride = stride 
else: 
      self.conv1_stride = stride 
      self.conv2_stride = 1 
复制代码

而此处的 baseline 则是指的 ResNet-PyTorch 。ResNet50是在 ImageNet 1K 训练数据集上从头训练,并在 ImageNet 1K 验证集上计算 top-1 accuracy。其训练技巧如下所示:

  • batch size: 32*8, 8卡,每张卡 32 bs
  • 优化器: SGD 且 Momentum 为 0.9
  • 学习率:初始学习率为 0.1, 每 30 个epoch 学习速率衰减为原来的 0.1
  • Epoch 总数:90
  • 权重正则: weight decay 为 1e-4
  • 训练数据增强
    • 随机缩放裁剪(RandomResizedCrop)
    • 随机水平翻转(RandomHorizontalFlip)
    • 随机颜色抖动 (ColorJitter)
  • 图片输入大小: 训练和测试时图像大小均为 224

基于上述配置,ResNet50 在 ImageNet 1k 验证数据集上 top-1 accuracy 是 76.1。

2.3 TIMM 训练技巧详情

TIMM 总结了目前最新的训练技巧,并将其应用到 ResNet 中,提出了 ResNet-rsb 版本。其有三个变种,分别对应 epochs 600, 300 和 100,称为 A1、A2 和 A3 版本,如下所示:

  • A1 是为了提供 ResNet50 上最佳性能模型
  • A2 是为了和 DeiT 进行相似对比(不是完全公平对比,因为 bs/训练 trick 不一样)
  • A3 是为了和原始 ResNet50 进行公平对比

作者在三个数据集上进行评估,具体为:

  • Val 表示在 ImageNet 1k 验证数据集
  • v2 表示 ImageNet 1k v2 版本数据集

以 A1 为例,其训练技巧如下所示:

  • batch size: 512x4=2048, 4卡,每张卡 512 bs
  • 优化器: LAMB
  • 学习率:初始学习率为 5x10^-3, 学习率调度策略采用 consine
  • Epoch 总数:600
  • 权重正则: weight decay 为 0.01
  • Wramup:总共 5 epoch
  • 训练数据增强
    • 随机缩放裁剪(RandomResizedCrop)
    • 随机水平翻转(RandomHorizontalFlip)
    • 随机增强 Rand Augment 7/0.5
    • Repeated Aug
    • Mixup Aug,参数 alpha 0.2
    • Cutmix Aug,参数 alpha 1.0
  • Loss 不再是采用 CE,而是替换为 BCE
  • 训练模型扰动
    • Label smoothing,参数 0.1
    • Stochastic-Depth, 参数 0.05
  • 图片输入大小
    • 训练输入网络的图片大小为 224x224
    • 基于 FixRes 策略,将图片 Resize 为 236, 然后 crop 成 224

可以看出,相比 ResNet-base 版本,由于训练 epoch 变长,训练中引入了很多新的数据增强和模型扰动策略。基于上述策略重新训练 ResNet50,在 ImageNet 1k 验证数据集上 top-1 accuracy 是 80.4。除了以上结果,作者还通过实验还得到了其他发现:

  • 加入如此多且强的数据增强和模型扰动,虽然可以提升模型性能,但是在网络训练早期收敛速度会很慢
  • 如果训练总 batch 为 512 时候,SGD 和 AdamW 都可以收敛,但是当训练的总 batch 为 2048,如果采用 SGD 和 BCE Loss,很难收敛

作者提供的非常详细的对比表如下所示:

同时,作者还验证 A1、A2 和 A3 这套设置在不同架构下的泛化能力。

其中加号表示 TorchVision 结果,而 ∗ 来自 DeiT 结果。 作者还对 ResNet-50 和 Deit-S 两者进行了对比,性能如下:

2.4 TorchVison 训练技巧详情

TorchVision 也推出了自己的训练技巧,其官方推文中有详细说明,其余相关讨论见 github.com/pytorch/vis…,最终结果如下所示:

作者还贴心地绘制了每个 trick 所带来的提升,如下所示:

训练技巧汇总:

  • batch size: 128x8=1024, 8卡,每张卡128 bs
  • 优化器: SGD 且 Momentum 为 0.9
  • 学习率:初始学习率为 0.5, 学习率调度策略采用 consine
  • Epoch 总数:600
  • 权重正则: weight decay 为 2e-05,且 norm 不进行 decay
  • Wramup:总共 5 epoch,采用线性 warmup,lr_warmup_decay 为 0.01
  • 训练数据增强
    • 随机缩放裁剪(RandomResizedCrop)
    • 随机水平翻转(RandomHorizontalFlip)
    • TrivialAugment
    • Mixup,参数 alpha 为 0.2
    • Cutmix,参数 alpha 为 1.0
    • 随机擦除 (Random Erase),概率参数为 0.1
  • 训练模型扰动
    • Label smoothing,参数 0.1
    • EMA,decay 参数为 0.99998,每隔 32 次迭代更新一次
  • 图片输入大小
    • 训练输入网络的图片大小为 176x176
    • 基于 FixRes 策略,对图片 Resize 为 232, 然后 crop 成 224

可以看出,rsb 和 torchvision 所提策略的重点都在于引入强的 aug、更多的模型扰动已经更长的训练 epoch。除此之外,作者还通过实验还得到了其他发现:

  • 使用一些更复杂的优化器,例如 Adam、RMSProp 和 SGD with Nesterov momentum,发现效果不会更好,但是作者没有实验 LAMB
  • 作者尝试了不同的 LR 调度器方案,例如 StepLR 和 Exponential。 尽管后者倾向于与 EMA 一起更好地工作,但它通常需要额外的超参数,例如定义最小 LR 才能正常工作,所以作者最终还是采用了对超参不那么敏感的 cosine
  • 作者尝试了不同的增强策略,例如 AutoAugment 和 RandAugment,但是这些都没有优于更简单的无参数 TrivialAugment
  • 使用双三次或最近邻插值并没有提供比双线性更好的结果
  • 使用 Sync Batch Norm 并没有比使用常规 Batch Norm 产生明显更好的结果
  • Mixup 和 Cutmix 两者配合使用时可以采用等概率的随机选择一种的方式,单独采用 Mixup 可以提升0.118,配合 Cutmix 可以额外提升 0.278
  • FixRes 中作者发现,训练时采用 176 图片尺寸,测试采用 272 尺寸效果最好,不过作者还是采用 224 ,目的是为了 baseline 保持一致,而如果训练时候采用 224 尺寸,测试采用 256 效果最好

3 高性能预训练模型在目标检测任务上的表现

本节探讨高性能预训练模型在目标检测任务上的表现。本实验主要使用 COCO 2017 数据集在 Faster R-CNN FPN 1x 上进行。具体设置请参考 MMDetection 配置文件 

# https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py 
_base_ = [ 
    '../_base_/models/faster_rcnn_r50_fpn.py', 
    '../_base_/datasets/coco_detection.py', 
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 
] 
复制代码

几个核心配置为:

  • 8 卡训练,总 batch size 为 16
  • 1x 训练时长即 12 epoch
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 
复制代码
  • 优化器相关配置是: SGD+ 0.9 momentum,lr 为 0.02,weight_decay 为 0.0001

如果想理解 Faster R-CNN 代码及其配置参数等细节信息可以参考 轻松掌握 MMDetection 中常用算法(二):Faster R-CNN|Mask R-CNN 一文。

3.1 仅替换预训练权重下表现

为了快速评估不同性能的预训练权重在 Faster R-CNN FPN baseline 配置下的性能,我们直接替换预训练权重,验证在 Faster R-CNN 上的性能,结果如下所示:

模型下载链接:
download.pytorch.org/models/resn…
download.openmmlab.com/mmclassific…
github.com/rwightman/p…
download.pytorch.org/models/resn…

需要说明的是,为了保证实验的公平性,我们在实验中设置了随机种子 (Seed=0),全部实验均在 8 x V100上进行,batch size = 16(8×2)。

从上表可以看出:替换成高精度的预训练权重的 ResNet 后,Faster R-CNN 没有显著提升甚至有些性能下降非常严重,这说明高精度预训练的 ResNet 可能不再适合用同一套超参,故而非常有必要对其进行参数调优。主要可能因为预训练模型的训练策略调整使 SGD 优化器不能很好适应预训练模型。 因此我们计划通过调整优化器、学习率和权重正则来对检测器进行微调。

3.2 ResNet baseline 预训练模型参数调优实验

由于 ResNet Strikes Back 中使用 AdamW 优化器来训练,我们尝试在目标检测下游任务中使用 AdamW 作为优化器,希望能够达到和使用 SGD 优化器相同的测试精度。

具体细节可见下表:

可以看到,在使用 AdamW 优化器,学习率为 0.0001 时,整体精度均可以超过 SGD 优化器,而在权重正则为 0.1 时,性能最优。

3.3 mmcls rsb 预训练模型参数调优实验

通过修改配置文件中预训练模型,我们可以将 ResNet 的预训练模型替换为 MMClassification 通过 rsb 训练出的预训练模型。在此基础上,我们分别通过 AdamW 与 SGD 来训练 Faster R-CNN ,从而获得 MMClassification 通过 rsb 训练出的预训练模型在检测任务上的效果。MMDetection 中配置文件写法为:

_base_ = [ 
    '../_base_/models/faster_rcnn_r50_fpn.py', 
    '../_base_/datasets/coco_detection.py', 
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 
] 
 
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.pth'   # noqa 
model = dict( 
    backbone=dict( 
        init_cfg=dict( 
            type='Pretrained', prefix='backbone.', checkpoint=checkpoint))) 
 
# 此处配置参数是最佳性能参数 
optimizer = dict( 
    _delete_=True, 
    type='AdamW', 
    lr=0.0002, 
    weight_decay=0.05, 
    paramwise_cfg=dict(norm_decay_mult=0., bypass_duplicate=True))             
复制代码

基于上一小节的先验,我们首先使用 AdanW 为优化器,学习设置为 0.0001。

具体数值见下表:

为了验证学习率对精度的影响,我们做了学习率验证实验。

具体数值见下表:

基于上述实验,我们发现在学习率为 0.0002 时,检测精度明显提高,因此我们设置了学习率为 0.0002 的对照实验:

具体数值见下表:

能够看到,在 lr=0.0002, weight decay=0.05 时,精度最高。同时也可以发现,weight decay 在某一个区间范围内对精度的影响不会很大,一旦超过这个区间,精度会下降明显

3.4 TIMM rsb 预训练模型参数调优实验

接下来,我们将 ResNet 的预训练模型替换为 PyTorch Image Models (TIMM) 的模型。在此基础上,我们通过 AdamW 来训练 Faster R-CNN ,从而获得 TIMM 预训练模型在检测任务上的效果。MMDetection 中的配置写法如下所示:

_base_ = [ 
    '../_base_/models/faster_rcnn_r50_fpn.py', 
    '../_base_/datasets/coco_detection.py', 
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 
] 
 
checkpoint = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth'   # noqa 
model = dict( 
    backbone=dict( 
        init_cfg=dict( 
            type='Pretrained', checkpoint=checkpoint))) 
 
# 此处配置参数是最佳性能参数 
optimizer = dict( 
    _delete_=True, 
    type='AdamW', 
    lr=0.0002, 
    weight_decay=0.03, 
    paramwise_cfg=dict(norm_decay_mult=0., bypass_duplicate=True))      
 
复制代码

基于上述微调先验信息,我们首先分别固定学习率为 0.0001 和 0.0002 ,调整 weight decay。实验结果如下:

\

具体数值见下表:

可以看到,尽管相比于基础的 Bbox mAP=37.4,有了一定的提高,最高能够达到 39.8。但是相比于使用 mmcls 的预训练模型得到的最高 Bbox mAP = 40.8 还是有一定的差距。之后我们还调整学习率来观察结果:

具体数值见下表:

综合前面结果,能够看到,AdamW 在学习率为 0.0001 和 0.0002 时精度差距不大,超过 0.0003 后,精度会明显下降。

3.5 TorchVision tnr 预训练模型参数调优实验

最后,我们还将 ResNet 的预训练模型替换为 TorchVision 通过新技巧训练出来的高精度模型,并分别通过 SGD 与 AdamW 来训练 Faster R-CNN,从而获得 TorchVision 通过新技巧训练出来的高精度模型在检测任务上的效果。MMDetection 中配置文件写法如下所示:

_base_ = [ 
    '../_base_/models/faster_rcnn_r50_fpn.py', 
    '../_base_/datasets/coco_detection.py', 
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 
] 
 
checkpoint = 'https://download.pytorch.org/models/resnet50-11ad3fa6.pth' 
model = dict( 
 backbone=dict( 
 init_cfg=dict( 
 type='Pretrained', checkpoint=checkpoint))) 
 
# 此处配置参数是最佳性能参数             
optimizer = dict( 
    _delete_=True, 
    type='AdamW', 
    lr=0.0001, 
    weight_decay=0.1, 
    paramwise_cfg=dict(norm_decay_mult=0., bypass_duplicate=True))      
 
复制代码

我们首先使用 SGD 算法来优化 Faster R-CNN,并尝试搜索最优的学习率与 weight decay:

SGD 算法下固定 weight decay 搜索最优 learning rate 实验

具体数值见下表:

SGD 算法下固定 learning rate 搜索最优 weight decay 实验

具体数值见下表:

根据实验结果可以看到,当保持训练参数一致,仅将预训练模型换为 TorchVision 的高精度预训练模型可以使精度上涨 2.2(37.4 -> 39.6) 个点。当学习率为 0.04,weight decay 为 0.00001 时,使用 r50-tnr 作为预训练模型,在 SGD 算法下优化的 Faster R-CNN 可以达到最高的 39.8% mAP 的结果。

接下来,我们尝试使用 AdamW 算法优化模型:

AdamW 算法下固定 weight decay 搜索最优 learning rate 实验

具体数值见下表:

AdamW 算法下固定 learning rate 搜索最优 weight decay 实验

具体数值见下表:

通过实验可以得出,在使用 AdamW 优化器时,学习率为 0.0001 的效果要比 0.0002 好上很多。而 weight decay 在 0.1 左右达到最高,其变化对最终的结果影响不大。当学习率使用 0.0001,weight decay 为 0.1 时,加载 r50-tnr 的 Faster R-CNN 达到最大精度的 40.2% mAP,相比于 SGD 上升了 0.4 (39.8 -> 40.2)。

4 总结

Through the previous experiments, we can see that the use of high-precision pre-training models can greatly improve the effect of target detection. The highest results of all pre-training models and the corresponding parameter settings are shown in the following table:

As can be seen from the table, using any high-performance pre-trained model can improve the performance of the target detection task by about 2 points. Among them, the high-precision model trained by MMClassification increases Faster R-CNN by 3.4 points and reaches the highest mAP of 40.8% , which proves that the use of high-performance pre-training model is of great help to the target detection task.

If you want to reproduce or experiment further, you can refer to the relevant configuration files and PRs

Welcome to the  MMDetection  experience, thanks to  the MMClassification  team for carefully proofreading the content of this article!

If our sharing has brought you some help, please like, subscribe, and follow~

Guess you like

Origin juejin.im/post/7086749765727158309