Notes on the paper "DARTS+: Improved Differentiable Architecture Search with Early Stopping"

1. Overview

Introduction: When NAS evolved to differentiable architecture search (DARTS), the time and memory cost of network search dropped dramatically. However, as the number of search epochs increases, the performance of DARTS actually collapses rather than improving continuously. Inspecting the architecture obtained after training shows that it tends to contain more and more skip connections instead of conventional convolution/pooling operations, which degrades the representational capacity of the network. After carefully analyzing the optimization procedure of DARTS, the paper points out that two sets of parameters are optimized: the architecture parameters $\alpha$ and the network weights $w$, and that the two are optimized alternately. These two processes involve not only cooperation but also competition. Previous work (P-DARTS) introduced regularization to suppress this "over-fitting" phenomenon; besides regularization, the problem can also be alleviated by early stopping, so the core of this paper is how to design the early-stopping criterion. The proposed method achieves a test error of 2.32% on CIFAR-10, 14.87% on CIFAR-100, and 23.7% on ImageNet.

The final architectures searched by the two methods discussed in the paper, with search run to convergence, are compared in Figure 1(a) below:
[Figure 1 of the paper]
It can be seen that when the search is run to convergence, the resulting network contains many more skip connections, which weakens its representational capacity.

2. Collapse analysis in DARTS

2.1 DARTS algorithm flow

The essential task of the DARTS algorithm is to optimize the architecture parameters $\alpha$ and the network weights $w$. The two sets of parameters are optimized alternately, and the optimization can be written as the bilevel problem

$$\min_{\alpha} L_{val}(w^{*}(\alpha), \alpha)$$
$$s.t.\ w^{*} = \arg\min_{w} L_{train}(w, \alpha)$$

After this process converges, the final architecture parameters $\alpha$ are obtained, and the discrete network is then derived in the following two steps (see the code sketch after this list):

  • 1) For every edge, select the most likely operation: $o^{(i,j)} = \arg\max_{o \in \mathcal{O},\, o \neq zero}\, p_o^{(i,j)}$
  • 2) For each node, keep the two incoming edges with the largest $\max_{o \in \mathcal{O},\, o \neq zero}\, p_o^{(i,j)}$.
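
As a rough illustration of the alternating optimization and the two derivation steps above, here is a minimal PyTorch-style sketch. The names (`search_one_epoch`, `derive_cell`, the optimizer arguments, `PRIMITIVES`) are assumptions for illustration rather than the actual DARTS code, and the $\alpha$ update uses the simple first-order approximation.

```python
import torch
import torch.nn.functional as F

# Operation list commonly used in DARTS-style search spaces (assumed here).
PRIMITIVES = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
              'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']

def search_one_epoch(model, train_loader, val_loader, w_optim, alpha_optim, criterion):
    """One epoch of alternating optimization: alpha on validation data, w on training data."""
    for (x_train, y_train), (x_val, y_val) in zip(train_loader, val_loader):
        # update architecture parameters alpha (outer problem, first-order approximation)
        alpha_optim.zero_grad()
        criterion(model(x_val), y_val).backward()
        alpha_optim.step()

        # update network weights w (inner problem)
        w_optim.zero_grad()
        criterion(model(x_train), y_train).backward()
        w_optim.step()

def derive_cell(alpha, primitives=PRIMITIVES, steps=4):
    """Steps 1) and 2): per-edge argmax over non-'none' ops,
    then keep the two strongest incoming edges of each intermediate node."""
    probs = F.softmax(alpha.detach(), dim=-1)         # [n_edges, n_ops]
    probs[:, primitives.index('none')] = 0.0          # exclude the zero operation
    genotype, start = [], 0
    for i in range(steps):                            # intermediate nodes of the cell
        n_in = i + 2                                  # node i can take i+2 inputs
        strength, best_op = probs[start:start + n_in].max(dim=-1)
        for e in strength.topk(2).indices.tolist():   # two strongest incoming edges
            genotype.append((primitives[int(best_op[e])], e))
        start += n_in
    return genotype                                   # list of (operation, input index)
```

With these helpers, the search loop calls `search_one_epoch(...)` once per epoch, and `derive_cell(model.alphas_normal)` yields the discrete normal cell once the search ends.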

2.2 The collapse problem in DARTS

Empirically, as the number of search epochs increases, the network tends more and more toward skip connections, as shown in Figure 1(b) and Figure 2(c).
[Figure 2 of the paper]
The relationship between the growing number of skip connections and network performance is shown in Figure 2(a): in the later stage of search, the two move in opposite directions. Moreover, judging from the searched architectures, too many skip connections reduce the representational capacity of the network and therefore its performance. This is a general phenomenon, not one that disappears when the dataset is changed. So what is the underlying reason for this collapse?

The paper attributes the collapse to the competition between the architecture parameters $\alpha$ and the network weights $w$ (although their relationship is cooperative at the beginning). After analysis, the specific reasons can be summarized as follows:

  • 1) The network weights $w$ have an advantage over the architecture parameters $\alpha$ in this competition: because $w$ contains many more learnable parameters, the final loss is comparatively less sensitive to $\alpha$ during alternating training, so $\alpha$ gradually falls to a disadvantage, as shown in Figure 2(a);
  • 2) The ordering of cells from the input to the output of the network. Cells near the input can access the freshest information, while the information reaching cells near the output contains more noise; to obtain a cleaner signal, the later cells tend to connect to earlier ones through skip connections. The following figure shows the cells searched at different depths (front, middle, and back) of the network:
    [Figure: cells searched at the front, middle, and back of the network]
    It can be seen that the closer a cell is to the output, the more skip connections are selected. One remedy would indeed be to search a different cell for each stage of the network; but if, as in DARTS, the same cell is shared everywhere, the preference for skip connections propagates from the back of the network to the front.

2.3 Early stopping mechanism

From the analysis above, the architecture parameters $\alpha$ and the network weights $w$ promote each other at first but compete with each other later, and this competition is accompanied by an increase of skip connections in the network. The number of skip connections, the number of search epochs, and the collapse of the network are therefore linked, and the paper proposes a first early-stopping criterion based on this observation:

Criterion 1: if the normal cell in the searched network contains two or more skip connections, stop the search.
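
A minimal sketch of this check, reusing the hypothetical `derive_cell` helper from the sketch in Section 2.1; the threshold of two skip connections follows the criterion above.

```python
def too_many_skips(alpha_normal, limit=2):
    """Criterion 1: count skip_connect ops in the normal cell derived from the current alpha."""
    ops = [op for op, _ in derive_cell(alpha_normal)]
    return ops.count('skip_connect') >= limit

# Checked once per search epoch (illustrative):
# if too_many_skips(model.alphas_normal):
#     break   # stop the search; further epochs would only add skip connections
```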

To determine the stopping epoch more precisely, the paper also proposes a second criterion:

Criterion 2: when the architecture parameters $\alpha$ remain stable for a certain number of epochs, the search can be stopped (see Figure 2(a)).
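
Below is a small sketch of this second criterion, under the assumption that "stable" means the argmax operation on every edge of $\alpha$ has not changed for a fixed number of epochs; the `patience` value and this particular stability measure are illustrative choices rather than the exact statistic used in the paper.

```python
class AlphaStabilityStop:
    """Criterion 2: stop when the selected op on every edge stays unchanged for `patience` epochs."""
    def __init__(self, patience=10):
        self.patience = patience
        self.prev = None
        self.stable_epochs = 0

    def should_stop(self, alpha_normal):
        current = alpha_normal.detach().argmax(dim=-1)   # currently selected op per edge
        if self.prev is not None and torch.equal(current, self.prev):
            self.stable_epochs += 1
        else:
            self.stable_epochs = 0
        self.prev = current
        return self.stable_epochs >= self.patience
```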

In addition to the rules above, three regularization mechanisms were proposed in P-DARTS:

  • 1) Only train the network for 25 epochs instead of 50;
  • 2) Add dropout after skip connections;
  • 3) Restrict the number of skip connections kept in the cell;

In PC-DARTS, partial-channel connections are used to shorten the search time, which also increases the number of epochs required for the network to converge.

3. Experimental results

CIFAR-10 and CIFAR-100 datasets:
[Results on CIFAR-10 and CIFAR-100, from the paper]
ImageNet dataset:
[Results on ImageNet, from the paper]

Origin: blog.csdn.net/m_buddy/article/details/110920805