ICLR 2023 | RevCol:可逆的多 column 网络,大模型架构设计新范式

我们给神经网络架构增加了一个维度!

自 ViT 时代到来之后,由一叠 blocks 堆起来构成的基础模型已经成为了广泛遵循的基础模型设计范式,一个神经网络的宏观架构由width宽度(channel 数)和 depth 深度(block 数)来决定。有没有想过,一个神经网络未必是一叠 blocks 组成的?可能是 2 叠,4 叠,或者…16 叠?

介绍一下我们最新的工作“Reversible Column Networks”,将解耦学习(disentangled feature learning)的思想引入模型设计中,提出以 reversible column 为单元来传递信息,既保证特征解耦,同时信息在网络中的传递不受到损失。整个网络结构包括了多个子网络(我们称为 column),column 间加入可逆的连接,通过将输入反复接入 column,逐渐分离 low-level 的纹理细节和 semantic 语义信息。这样做的好处在于,既能够保证在预训练中保持高精度,又保证了 low-level 的信息不丢失以在下游任务(detection,segmentation)中能够达到更好效果。为了验证这套设计模式在大模型大数据下的表现,我们在 RevCol 上做了一个2B 参数的纯 CNN 超大模型,且只使用了 3x3 的卷积核。在 ImageNet-1K 上达到了 90% 的 Top-1 Accuracy,下游的检测和分割任务上双双达到 60+的水平,COCO AP box 63.8%,ADE 20k mIoU 61.0%。此外,RevCol 架构依然遵循了可逆神经网络的设计范式,也就继承了可逆网络天然的节省显存的优势,文中的大部分实验均可在 2080ti 上完成。而节省显存这件事,对于大模型训练无疑是重要的。

08ccbf7416fca2d6a7123b49973fa4a0.png

arxiv: https://arxiv.org/pdf/2212.11696.pdf

github: https://github.com/megvii-research/RevCol

(旷视研究院Foundation Model Group,转载请注明出处)

c89ee160988c29d30790c4f750fff874.png

一、背景

早在 CNN 时代,我们就发现单纯把像 ResNet 用在下游任务上难以发挥优势,但使用 HRNet,FPN 这样的多尺度融合的方式后,就能够取得较好的效果。然而单纯把 HRNet 和 FPN 放在上游分类任务中,效果又不如 ResNet 这样的直筒网络。那么多尺度融合的网络和直筒网络到底是哪里做对了,又哪里不足呢?

要回答这个问题,我们可以借助 Tishby 提出的 Information Bottleneck principle 来审视这些结构。Information Bottleneck 讲,网络向前传播的时候,会逐步丢弃(压缩)和任务无关的信息,而保留对任务有帮助的信息。对于分类任务上的网络来说,靠近输入的浅层 feature 蕴含了大量和分类无关的 low-level 信息,而靠近输出的深层主要是 semantic  语义信息。单从分类上看,这样似乎是合理的,但如果把在分类任务上预训练得到的网络再用于下游,预训练阶段的信息损失就会影响到下游任务上的效果。

所以,最好的 backbone 应该是具备把 task-relevant 的语义信息 decoupled 出来放到一些dimensions中,且在网络中保留尽可能多的输入信息的。这话 Bengio 在2012年Disentangling factors of variation 的文章中也提到了,原句是: we require a means of feature extraction that disentangles these factors in the data rather than simply learn to represent some of these factors at the expense of those that are lost in the filter pooling operation.

9a912b1eb63892b0c79a3d6b80413dfb.png

我们的 RevCol 通过在结构上精妙的设计,在 e2e 的训练 pipline 下实现了 disentangle 的目标。如 Figure 1 所示,一般的直筒网络(single column)的信息传递方式是,越靠近 input的部分的信息越偏向 low-level,越靠近 loss 的位置越 semantic。而 RevCol 采用了 multi-input的设计,每个 column 的起始位置都是 low-level 的信息。随着column的 iteration,在column 的最末端,feature 中的语义信息就逐渐被提取了出来。在 column 之间采用了 Reversible 的连接设计,也就是说后面 column 可以倒推出前面 column 的信息。这样从最后一个 column 可以一路倒推回第一个 column,以此保证信息在 column 间传递的时候是无损的。同时在 column 的最末端加入中间监督,显式地约束每个 column 的输出信息的表征,以此来保证语义信息能随着 column iteration 被 decouple 出来。

二、方法

84898d7c68bf43013b427b1c2e2b019f.png

RevCol 结构中包含了很多 subnet,我们称作为 column。依照 column iteration 的形式把multi-level reversible 单元排列起来就构成了网络的宏观架构。

2.1        Multi-Level Reversible Unit

如果各位读者了解过 RevNet(Reversible 网络的开山之作),应该对立面提到的双路交叉的网络结构印象深刻。这里我们再赘述一遍:如 Figure 2a 所示,RevNet 先将输入划分为    和  (这俩可以保持一致),后面的每一个 Reversible block 的输入,都来自于前面两个block 的输出。

(1)

Equation 1 是 RevNet 的前传和反推的公式,其中  表示当前的 block,  表示上一个block,  表示上上一个。当前 block 的输入由两个部分构成,第一部分来源于上一个block 的输出  ,过一个 Non-linear Function F(·)(可以是 Conv/BN/Relu/Residual Block)。第二部分是上两个 block 的输出  经过一个 Reversible Operation,比如channel-wise scaling。然后这两个部分加和。这样的约束保证了在倒推的时候  可以通过重新把更先前一步已推得的  重新输入 Ft(·)中计算出结果,然后把 forward 的式子反过来。加变减,乘变除,反推出来。

RevNet 的 Equation 是有一些天然缺陷的。在加号两侧的 tensor 必须保持一样的 shape,也就意味着 、、, 这一系列输出不能呈现 hierarchical 的特性。这也是为何 RevNet 无法做到从输入到输出全程 Reversible,而是退而求其次选择了在每个 stagesame resolution)内部 Reversible,down-sample 的时候依然会丢失信息。通过这种方式我们没有办法把多个 hierarchical 的 column 穿起来,所以,得改。

我们把 RevNet 的 Equation 推导到了更加 generalized 的形式。下式 Equation 2 中,我们在 RevNet 的第一部分 Ft(·)的输入中增加了更多的 x,第二部分维持不变。也就是说现在的式子不再是依赖于上两个 block 的输出,而是上 m 个,其中有一个放在了第二部分,剩下的都放在第一部分。倒推的时候只要拿到第一部分的所有 m-1 个输入 、 就能够计算出第二部分的输入  了。而这 m-1 个输入在更早的时刻就已经被推得了。

(2)

 ,, 

 , 

【把这个公式装进column】如果把一个网络的所有 feature 看做一个长序列,我们把每 m 个feature 划分为一个 group, 、、 那么只要拿到了一个 group,就可以逐一、步进地推出下一个 group 中的全部 feature。同理,倒推的时候也是由当前 group 逐一倒推上个 group 的值。也可以理解为滑窗。这样,如果把每个分组的 feature 抽成一个 column,然后 column by column 的 forward/inverse,就构成了RevCol 的基本结构。

【Equation 2 带来了哪些 benefits?】首先,RevNet 中对于 shape 的限制被缓解了,我们只要保证 和 的 shape 一致就行了,  到  因为 Ft(·)能把它们调整到和  一致。所以 hierarchical 的所有优点我们依然保留着。另外,这个结构和任何单 column 的网络都可以复合,网络中的 features 刚好对应到一个 group 的 features 上。这些 benefits 对把网络做大做强都是至关重要的。

2.2 基本结构

【宏观结构】

c0e7972ef9eefc603aa5efaf709ddfae.png

我们在真正实现 RevNet 的时候,对上面的 Multi-level Reversible 做了一些精简。Ft(·) 的输入是 m-1 个,我们把它精简为 2 个了,增加更多的输入收益有限。参照 Figure 2 (c)来看,一个输入来源于同 column 的上一个 Level(红色线),另一个输入来源于先前 column 中的下一个 Level(蓝色线)。两个输入一个代表了高层次语义信息,一个代表的是低层次纹理信息,够了。

【微观结构】

502d6fe19d7952aea9017fa9a2275613.png

每一个 Level 中,先用一个 Fusion 单元(Figure 5 c),把不同 shape 的输入调整到同一个 shape,后面再经过一堆 ConvNeXt Blocks,得到输出,这些就是公式中的 Ft(·),然后再和 Reversible operation 的输入加和,得到最终结果。值得注意的是,我们把原本的 ConvNeXt block 中 7x7 的 kernel size 改成了 3x3,大 kernel 的收益在 revcol 上有限,但小 kernel 是真的快!

2.3 中间监督

我们还做了个 plug-in 的中间监督方法,在不改变 e2e 训练的 pipline 下,上下游能额外带来 1 个点以上的收益。我们发现,由于 Reversible unit 的后加和设计,在每个 column 的最后,都用 Reversible operation 的方式一路加到了最后一个 column 然后接 loss,这样就会让前面 column 的底部离 loss 很近。离 loss 近,就意味着这个位置的 feature 中包含的主要是语义信息。那么第一个 column 本身是会丢信息的,第一个 column 如果包含了太多语义信息丢失了太多纹理信息,那后面再 Reversible,收益也很小了。它不能坍缩掉。

所以我们在 column 底部加了一个分类 head,和一个 feature 重建的 head,分别接 CE、BCE Loss,然后随着 column 加深,逐步调节这俩 loss 的比值,最后一个 column 重建的loss为0,分类的 loss 占比为 100%。

(3)

The total loss L in is the summation of all compound loss:

2.4 模型设计

在 RevCol 这套架构下,一个模型有三个维度了:channel 数(width),单个 column block 数(depth),还有 column 数。我们在设计模型的时候,发现增加 column 数的收益几乎等同于同时增加 width 和 depth,所以做了个简单粗暴的 scale up rule:一个 small(8 column)就是俩 tiny(2x4 column),一个 base(16 column)就是四个 tiny(4* 4column)。只不过这里在 base 上为了和竞品对齐计算量稍微做了点调整。

dc05bf3bd1a4810f9b7e1d712b271055.png

哦对了,RevNet 还有个人见人爱的特性,也就是其他 Reversible papers 里纷纷宣传的节省显存。RevCol-T/S/B 用的单 column 计算量几乎是一样的,增加 column 后显存中只增加了param 的存储,所以这仨尺寸的模型基本上占用的显存是一样的。他们都可以在 RTX 2080ti (11G)中训练。在 Huge(2B 参数)上,我们也开启了 Reversible 的计算方式,以此提高训练效率。如果无法节省显存,Huge 的训练代价恐怕要增加很多倍。

三、实验结果

【ImageNet Classification】

09754ad8fcdf850ddfd9f1ddae80643c.png

除了 2B 参数的模型以外,我们还收集了 168Million 的私有数据集(Megdata-168M),weakly-label 的标签。用来预训练。XL模型(800M param),在 22k 下能达到 88.2,经过Megdata-168M,的训练后能够涨到 89.4。Huge(2.1 B param)224 pretrain,640x640 Finetune,能够达到 90.0% Top-1 Accuracy。这个模型的训练开销:预训练总共 1600 个ImageNet Epochs,训练一次使用 80 块 A100 需要 14 天。

【COCO Object Detection】

48cae0a848a80ebe50684d8224d3dcf6.png

【ADE Semantic Segmentation】

3b2fcc43ec48e24446c3963370dde3a0.png

在 COCO 上,使用 DINO 的框架,经过 Object 365 进一步 Finetune 之后,RevCol-H 能够达到 63.8 的 AP box。在 ADE 20k,使用 Mask2Former 框架,mIoU 能够达到 61%。

【Foundation Models】

f2997fa8c603a91c23f9f9e23a451778.png

我们列举了各家的 Foundation Models 并且做了个对比。RevCol-H 作为一个单模态模型(Megdata-168M 数据集只包含图片没有 language)且标签用的是 semi-labeled 方式,我们没有使用 Mask Image Modeling 预训练,我们还是个 CNN。最终的上下游任务都能够达到和其他单模态多模态大模型 comparable 的结果,比如多模态模型 BEiT3,多模态模型Florence,单模态超大模型外加 MIM 预训练 setting 下的 Swinv2-G。

四、结论和展望

我们做 RevCol 的第一个版本的目标,是在纯视觉任务下验证了它 scale up 的 capability。但CV 大模型不应只是局限于分类检测等等这些任务下,尤其是 ChatGPT 现象级爆火之后,CV 的大模型未来在哪里更加值得深思。我们能看到的潜在未来,包括视频理解,多模态模型,生成模型,自动驾驶等等。我们坚信 RevCol 这样一套宏观架构是能够普遍使用的,CV 的未来在哪里,RevCol 就在哪里,顶峰相见吧各位。

a6a6e7484ba3e2eb534075bcf71deb4f.gif

猜你喜欢

转载自blog.csdn.net/Megvii_tech/article/details/129095758