《DARTS+:Improved Differentiable Architecture Search with Early Stopping》论文笔记

1. 概述

导读:NAS演化到使用可微网络结构DARTS的时候,已经将网络搜索的时间与显存消耗大大降低。但是随着训练epoch的增加DARTS的性能实际是collapse的,并不是持续提升。训练完成之后,发现最后搜索出来的网络结构趋向于较多的skip connection,而不是常规的卷积/池化等操作,这就导致了网络表达能力的病态。文章经过仔细分析DARTS优化的流程步骤,发现其中具体优化两部分参数:结构参数 α \alpha α和网络参数 w w w,这两个参数的优化过程其实是交替进行的,实际上这两个过程是不仅包含合作还包含了竞争关系的。之前的一些工作(PDARTS)中引入了一些正则化的方式来防止“过拟合”现象的发生,除了正则化的约束,其实还可以通过early stop的方式缓解这个问题。那么这篇文章的核心便是怎么去设置这个早停的判别机制了。文章的方法在CIFAR-10上test错误率为2.32%,CIFAR-100数据集上为14.87%,ImageNet数据集上为23.7%。

将文章提到的训练早停与训练到收敛这两种方式最后的到的网络结构进行比较,其结果见下图a所示:
在这里插入图片描述
可以看到当训练到收敛的时候网络会存在较多skip connection,这就导致了网络表达能力不足。

2. DARTS中colapse分析

2.1 DARTS算法流程

DARTS算法中实质的任务是优化结构参数 α \alpha α和网络参数 w w w,这两个部分是交替进行优化的,其优化过程可以描述为:
min ⁡ α L v a l ( w ∗ ( α ) , α ) \min_{\alpha}L_{val}(w^{*}(\alpha),\alpha) αminLval(w(α),α)
s . t .   w ∗ = arg min ⁡ w L t r a i n ( w , α ) s.t.\ w^{*}=\argmin_{w}L_{train}(w,\alpha) s.t. w=wargminLtrain(w,α)
等待上述的过程收敛之后,便会得到最终的结构参数 α \alpha α ,那么网络的生成可以描述为下面的两步:

  • 1)得到网络结构的概率 o ( i , j ) = arg max ⁡ o ∈ O , o ≠ z e r o p o ( i , j ) o^{(i,j)}=\argmax_{o\in\mathcal{O},o\ne zero}p_o^{(i,j)} o(i,j)=oO,o=zeroargmaxpo(i,j)
  • 2)对于每个节点选择与之连接概率最大概率的两个支路 max ⁡ o ∈ O , o ≠ z e r o p o ( i , j ) \max_{o\in\mathcal{O},o\neq zero}p_o^{(i,j)} maxoO,o=zeropo(i,j)

2.2 DARTS中的collapse问题

从表象上看,随着训练过程迭代epoch的增加网络会更加趋向于选择skip connection,可见图1的b图以及图2的c图。
在这里插入图片描述
那么skip connection的增加与网络性能的表达之间的关系可以见图2的a图所示,可以看到他们之间在后期是呈现相反的关系。而且从搜索得到的网络结构上来说,过多的skip connection会导致网络表达能力下降,从而网络性能降低。而且这是一个普遍现象,并不是说换了数据集之后就能避免这个情况的发生。那么造成这样模型collapse的深层原因是什么呢?

文章指出这是由结构参数 α \alpha α和网络参数 w w w之间对抗造成的(尽管开始的时候是相互合作的关系)。文章对此进行分析之后具体的原因可以归纳为:

  • 1)网络参数 w w w相对于结构参数 α \alpha α在竞争中更加具有优势,由于网络参数 w w w相对来说拥有更多的可学习参数,这就导致了其对交替训练过程中最后损失相对不那么敏感,这就导致结构参数 α \alpha α逐渐在竞争过程中逐渐处于劣势,详见图2的a图所示;
  • 2)网络cell与排列的前后关系,在整个网络输入端的cell能接触到最新的信息,而在整个网络后端的cell接触到的信息是包含了较多噪声的,为了获取更好的分辨能力,这就使得最后的cell会使用skip connection与之前的cell相连。下图中展示的便是整个网络的不同阶段(前中后)的网络结构:
    在这里插入图片描述
    其中够可以看到越往后端靠,选择的skip connection越多,对于这里确实可以在不同的网络段选择不同的cell。但是如果像DARTS中那样选择一样的cell那么skip connection就会从后端往前端进行传递了。

2.3 早停机制

在上文的结论得出在网络搜索的开始阶段结构参数 α \alpha α和网络参数 w w w是相互促进的,但是在后面就开始相互竞争了,也伴随着网络中skip connection的增加,因而skip connection和迭代epoch以及网络的collapse之间是存在一些联系的。因而文章提出了这对这个规律提出了第一点早停准则:

当搜索网络中若是存在一个normal cell中具有2个skip connection就停止训练。

为了更加精确确定提前终止的epoch,文章还提出了更加另外一种停止准则:

在当网络结构参数 α \alpha α在一段时间(一定数量的epoch)保持稳定,那么就可以将其停止(参考图2的a图)。

除了上述的规则之外,在PADRTS中提出3点正则化机制:

  • 1)只对网络训练25个epoch而不是50个;
  • 2)在skip connection后面添加dropout;
  • 3)认为设定cell中skip connection的数量;

而在PC-DARTS中使用了partial-channel连接的方式缩小搜索时间,也使得网络收敛所需的epoch数量增加。

3. 实验结果

CIFAR10 and CIFAR100数据集:
在这里插入图片描述
ImageNet数据集:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m_buddy/article/details/110920805