How should each loss weight be designed in multi-task learning?

Source: How should each loss weight be designed in multi-task learning? - Zhihu (zhihu.com)

Multiple losses are common in deep learning, for example:

  • Object detection: Taking YOLO as an example, its loss function is typically composed of several parts, including a center-coordinate error loss, a width-and-height error loss, and a confidence error loss. There is also a foreground/background issue: the loss is computed differently depending on whether the current grid cell contains a target, so the ratio between the two must be controlled to prevent the gradient from being dominated by one of them (a weighted-sum sketch follows this list).
  • Semantic segmentation: Especially for tasks such as medical imaging, where the difference between foreground and background is very small, many methods introduce a deep supervision mechanism. The loss is then no longer a single two-term loss; almost every stage of the network produces its own loss branch. In such a situation you cannot simply enumerate every weight combination and tune by trial and error.
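
As a minimal illustration of controlling the ratio between loss terms, the sketch below combines per-part detection losses with fixed weights in the spirit of YOLO's lambda_coord and lambda_noobj; the function name and the weight values are illustrative, not YOLO's actual implementation.

```python
import torch

def detection_loss(coord_loss, wh_loss, conf_obj_loss, conf_noobj_loss,
                   lambda_coord=5.0, lambda_noobj=0.5):
    """Weighted sum of detection loss parts; the background (no-object)
    confidence term is down-weighted so it does not dominate the gradient."""
    return (lambda_coord * (coord_loss + wh_loss)
            + conf_obj_loss
            + lambda_noobj * conf_noobj_loss)

# Example with scalar tensors standing in for the per-part losses.
total = detection_loss(torch.tensor(0.8), torch.tensor(0.3),
                       torch.tensor(0.5), torch.tensor(2.0))
```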

        The multi-loss problem involves two core challenges of multi-task learning: the network architecture (how to share) and the loss function (how to balance).

(1) Unify every loss to the same order of magnitude. The reason is that the scales of the loss functions of different tasks can differ greatly, so weights are needed to bring the losses onto a common scale. In general, the gradient magnitudes of different tasks differ during convergence, and their sensitivity to the learning rate differs as well. Unifying the losses to the same order of magnitude prevents a loss with small gradients from being drowned out by a loss with large gradients, so that the learned features generalize better.
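
One simple way to put this into practice is sketched below: each loss is divided by a running estimate of its own magnitude, so every term contributes at roughly the same order of magnitude. This is a heuristic illustration under that assumption, not the method of any particular paper.

```python
import torch

class LossScaler:
    """Keeps an exponential moving average of each loss's magnitude and
    divides by it, so every term contributes at roughly the same order
    of magnitude (heuristic sketch)."""

    def __init__(self, num_losses, momentum=0.99, eps=1e-8):
        self.ema = [None] * num_losses
        self.momentum = momentum
        self.eps = eps

    def __call__(self, losses):
        total = 0.0
        for i, loss in enumerate(losses):
            value = loss.detach().abs()
            if self.ema[i] is None:
                self.ema[i] = value
            else:
                self.ema[i] = self.momentum * self.ema[i] + (1 - self.momentum) * value
            total = total + loss / (self.ema[i] + self.eps)
        return total

# Usage: scaler = LossScaler(num_losses=2); total = scaler([loss_a, loss_b])
```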

For a relatively recent solution to this problem, see the CVPR 2018 paper "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics". It proposes weighting each task's loss by its homoscedastic (task-dependent) uncertainty, treating that uncertainty as learnable noise during training.
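
A minimal sketch of that idea, assuming one learnable log-variance per task and the commonly cited simplified form sum_i exp(-s_i) * L_i + s_i (the paper derives slightly different constant factors for regression and classification losses).

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Learnable log-variances s_i; each task loss is scaled by exp(-s_i)
    and regularized by +s_i, following Kendall et al., CVPR 2018 (simplified)."""

    def __init__(self, num_tasks):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0.0
        for i, loss in enumerate(losses):
            total = total + torch.exp(-self.log_vars[i]) * loss + self.log_vars[i]
        return total
```

Because the log-variances are ordinary parameters, they are optimized jointly with the network, so the weighting adapts as training progresses.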

(2) The network architecture is generally either hard parameter sharing or soft parameter sharing. In hard parameter sharing, all tasks use the same set of shared feature layers, which greatly reduces the risk of overfitting; in soft parameter sharing, each task has its own feature parameters, with constraints between the sub-tasks. As for the question of "the order of magnitude of the gradients in different parts of the network", even when the gradient values are similar, different losses contribute differently to task performance, so a suitable balance point between the losses still has to be found (a hard-sharing sketch follows).
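
A minimal hard-parameter-sharing model, assuming a shared MLP encoder and one linear head per task; the layer sizes and number of tasks are illustrative.

```python
import torch
import torch.nn as nn

class HardSharingModel(nn.Module):
    """Shared encoder (hard parameter sharing) with task-specific heads."""

    def __init__(self, in_dim=128, hidden=64, task_out_dims=(10, 1)):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.heads = nn.ModuleList([nn.Linear(hidden, d) for d in task_out_dims])

    def forward(self, x):
        z = self.encoder(x)          # features shared by all tasks
        return [head(z) for head in self.heads]
```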

Regarding the relationship between network architecture and loss function: after many experiments and much accumulated experience, we found that a good network design increases the generalization ability of the features more effectively than a good loss design, and a good dataset does so more effectively than a good network design.

Finally, here is an example of training Face Reconstruction and Dense Alignment with multi-task learning. Briefly: the model has 4 tasks, the dense alignment part predicts about 40,000 key points, and the reconstructed face can blink and open its mouth.

https://github.com/Star-Clouds/FRDA

Multi-task learning

Definition

Multi-Task Learning (MTL) is an active research field in machine learning. It is a learning paradigm that aims to jointly learn multiple related tasks and to improve their generalization performance by exploiting the common knowledge among them. In recent years, many researchers have successfully applied MTL to fields such as computer vision, natural language processing, reinforcement learning, and recommender systems. Current research on MTL mainly focuses on two perspectives: network architecture design and loss weighting strategies.

Network architecture design

In the design of network architectures, the simplest and most popular approach is hard parameter sharing (HPS, LibMTL.architecture.HPS), as shown in Figure 1: the encoder is shared among all tasks and each task has its own task-specific decoder. Since most of the parameters are shared between tasks, this architecture can easily lead to negative transfer (one task's training hurting the others) when the correlation between tasks is not strong enough. To better handle the relationships between tasks, different MTL architectures have been proposed.

Figure 1

In the figure above, the left side shows the single-input setting and the right side the multi-input setting, with hard parameter sharing as the example. LibMTL already supports a variety of state-of-the-art architectures; see LibMTL.architecture for details.

Loss weighting strategy

Balancing the multiple losses that correspond to the multiple tasks is another way to handle task relationships, since the shared parameters are updated by all task losses. Different methods have therefore been proposed to balance either the losses or the gradients.

LibMTL currently supports a variety of state-of-the-art weighting strategies, please refer to the LibMTL.weighting branch for details.

For example, gradient-balancing methods such as MGDA first compute the gradient of each task and then aggregate these gradients in various ways. To reduce the computational cost, the gradient of the representation produced by the shared encoder (rep-grad for short) can be used to approximate the gradient of the shared parameters (param-grad for short).
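
As an illustration of the gradient-aggregation idea, the sketch below uses the closed-form two-task min-norm solution that underlies MGDA (for more than two tasks MGDA solves a quadratic program; this is only the two-task special case, and the function name is ours).

```python
import torch

def min_norm_two_tasks(g1, g2):
    """Closed-form solution of min_a ||a*g1 + (1-a)*g2||^2 for a in [0, 1],
    the two-task special case underlying MGDA."""
    diff = g1 - g2
    denom = torch.dot(diff, diff).clamp(min=1e-12)
    alpha = torch.dot(g2 - g1, g2) / denom
    alpha = alpha.clamp(0.0, 1.0)
    return alpha * g1 + (1 - alpha) * g2

# g1 and g2 would be the flattened per-task gradients with respect to either
# the shared parameters (param-grad) or the shared representation (rep-grad).
```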

The PyTorch implementation of rep-grad is shown in Figure 2: the computation graph is separated into two parts with a detach operation. Internally, LibMTL unifies these two cases into one training framework, so you only need to set the command-line argument rep_grad correctly. Also, rep_grad does not conflict with multi_input.

Figure 2

The diagram above illustrates how the gradient of the representation is computed.
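
A minimal sketch of the detach trick, assuming a toy shared encoder and two task heads; the names are illustrative and this is not LibMTL's internal code. For brevity, the gradients of the task-specific heads themselves are not accumulated here.

```python
import torch
import torch.nn as nn

# Toy shared encoder and two task heads (illustrative only).
encoder = nn.Linear(16, 8)
heads = nn.ModuleList([nn.Linear(8, 4), nn.Linear(8, 1)])
criteria = [nn.CrossEntropyLoss(), nn.MSELoss()]

x = torch.randn(32, 16)
targets = [torch.randint(0, 4, (32,)), torch.randn(32, 1)]

rep = encoder(x)                              # part 1 of the graph: shared encoder
rep_d = rep.detach().requires_grad_(True)     # detach: cut the graph here

# Part 2: each task loss is back-propagated only to the detached representation,
# yielding the per-task rep-grad without touching the shared parameters yet.
rep_grads = []
for head, criterion, target in zip(heads, criteria, targets):
    loss = criterion(head(rep_d), target)
    rep_grads.append(torch.autograd.grad(loss, rep_d)[0])

# Aggregate the per-task rep-grads (a simple mean as a placeholder for a real
# weighting strategy), then push the result back through the shared encoder.
combined = torch.stack(rep_grads).mean(dim=0)
rep.backward(combined)
```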

LibMTL

LibMTL [1] is an open source library for multi-task learning built on PyTorch.
