CrossKD principle and code analysis

Paper: CrossKD: Cross-Head Knowledge Distillation for Dense Object Detection

Official implementation: https://github.com/jbwang1997/CrossKD

Foreword

Distillation methods fall into two categories: prediction mimicking and feature imitation. Hinton's pioneering 2015 work on knowledge distillation, "Distilling the Knowledge in a Neural Network", belongs to prediction mimicking, while "FitNets: Hints for Thin Deep Nets" is a typical feature-imitation method. For a long time, however, prediction mimicking was found to be less effective than feature imitation. LD ("Localization Distillation for Dense Object Detection", CVPR 2022) showed that prediction mimicking can transfer task-specific knowledge, and that it is beneficial for the student to perform prediction mimicking and feature imitation at the same time. This prompted the authors to further explore and improve prediction mimicking.

The innovation of the paper

In prediction mimicking, the student's predictions need to mimic both the ground truth (GT) and the teacher's predictions, but the teacher's predictions often differ considerably from the GT, so the student faces contradictory learning targets during distillation. The authors believe this is the main reason preventing prediction mimicking from achieving higher performance.

To alleviate this conflict of learning targets, the paper proposes a new distillation method, CrossKD, which feeds the intermediate features of the student detection head into the teacher's detection head and distills the resulting cross-head predictions against the teacher's original predictions. This has two benefits. First, the KD loss no longer acts on the student's own prediction layers, avoiding the conflict between the original detection loss and the KD loss. Second, since the cross-head predictions and the teacher's predictions share part of the teacher's detection head, the two are relatively consistent, which alleviates the prediction discrepancy between student and teacher and improves the training stability of prediction mimicking.

Prediction mimicking, feature imitation, and the CrossKD proposed in this paper are illustrated in (a), (b), and (c) of Figure 1.

Method introduction

The overall architecture of CrossKD is shown in Figure 3

Given a dense detector such as RetinaNet, each detection head usually consists of a series of convolutional layers, denoted \(\{C_{i}\}\). For simplicity, assume each detection head has \(n\) convolutional layers in total (e.g. \(n=5\) in RetinaNet: 4 hidden layers and 1 prediction layer). We use \(f_{i}, i\in\{1,2,\dots,n-1\}\) to denote the output feature map of \(C_{i}\), and \(f_{0}\) to denote the input feature map of \(C_{1}\), i.e. the FPN output. The prediction \(p\) is the output of the last convolutional layer \(C_{n}\); the final predictions of the teacher and the student are denoted \(p^{t}\) and \(p^{s}\) respectively.
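
To make the notation concrete, here is a minimal toy sketch of such a head (not the mmdetection implementation; DenseHead and all names below are made up for illustration), where feats[i] corresponds to \(f_{i}\) and the last layer produces \(p\):

    import torch
    import torch.nn as nn

    class DenseHead(nn.Module):
        """Toy dense detection head: (n-1) hidden 3x3 convs plus 1 prediction conv."""

        def __init__(self, channels=256, num_outputs=80, n=5):
            super().__init__()
            # C_1 ... C_{n-1}: hidden layers
            self.hidden = nn.ModuleList(
                nn.Conv2d(channels, channels, 3, padding=1) for _ in range(n - 1))
            # C_n: prediction layer
            self.pred = nn.Conv2d(channels, num_outputs, 3, padding=1)

        def forward(self, x):
            feats = [x]                     # feats[0] = f_0, the FPN output
            for conv in self.hidden:
                x = torch.relu(conv(x))
                feats.append(x)             # feats[i] = f_i, output of C_i
            p = self.pred(feats[-1])        # prediction p from the last layer C_n
            return p, feats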

CrossKD feeds the intermediate features \(f_{i}^{s}, i\in\{1,2,\dots,n-1\}\) of the student detection head into \(C^{t}_{i+1}\), i.e. the \((i+1)\)-th convolutional layer of the teacher detection head, which yields the cross-head prediction \(\hat{p}^{s}\). Unlike previous methods, the KD loss is not computed between \(p^{s}\) and \(p^{t}\), but between \(\hat{p}^{s}\) and \(p^{t}\).
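
Written out from the description that follows (the notation here is approximate, with \(r\) indexing spatial locations and \(\mathcal{D}(\cdot,\cdot)\) denoting the branch-specific distance between the cross-head and teacher predictions), the KD loss takes roughly the form

\[
\mathcal{L}_{KD}=\frac{1}{|\mathcal{S}|}\sum_{r}\mathcal{S}(r)\,\mathcal{D}\left(\hat{p}^{s}_{r},\,p^{t}_{r}\right)
\]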

where \(\mathcal{S}(\cdot)\) and \(|\mathcal{S}|\) are the region selection function and the normalization factor, respectively. The authors do not design a complicated \(\mathcal{S}(\cdot)\): for the classification branch \(\mathcal{S}(\cdot)\) is a constant 1, while for the regression branch \(\mathcal{S}(\cdot)\) is 1 in foreground regions and 0 in background regions.
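
As a rough sketch of this region-weighted KD loss (this is not the repository's actual loss implementation; the distance functions are simple stand-ins, and names such as crosskd_kd_losses, fg_mask and T are made up for illustration), the computation could look like:

    import torch
    import torch.nn.functional as F

    def crosskd_kd_losses(reused_cls, tea_cls, reused_box, tea_box, fg_mask, T=2.0):
        """Simplified region-weighted prediction-mimicking losses.

        reused_cls / tea_cls: (N, num_classes, H, W) cross-head / teacher cls logits
        reused_box / tea_box: (N, 4, H, W) cross-head / teacher box predictions
        fg_mask:              (N, H, W) bool tensor, True at foreground locations
        """
        # Classification branch: S(.) = 1 everywhere, normalized over all locations.
        p_s = torch.sigmoid(reused_cls / T)
        p_t = torch.sigmoid(tea_cls / T)
        cls_kd = F.binary_cross_entropy(p_s, p_t, reduction='mean')

        # Regression branch: S(.) = 1 on foreground, 0 on background,
        # normalized by the number of foreground locations.
        num_fg = fg_mask.sum().clamp(min=1)
        box_dist = (reused_box - tea_box).abs().sum(dim=1)   # L1 stand-in, (N, H, W)
        reg_kd = (box_dist * fg_mask).sum() / num_fg
        return cls_kd, reg_kd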

Experimental results

First come some ablation experiments. The teacher network is GFL with a ResNet-50 backbone, and the student uses ResNet-18.

Positions to apply CrossKD.

As mentioned above, the output of the \(i\)-th convolutional layer of the student detection head is sent to the teacher's head. Here the authors compare the impact of different values of \(i\) on the final result, where \(i=0\) means the FPN output features are fed directly into the teacher's head. The results are as follows

It can be seen that the model's final accuracy is highest when \(i=3\), so \(i=3\) is used as the default configuration in subsequent experiments.

CrossKD v.s. Feature Imitation.

The authors compare CrossKD with PKD, a SOTA feature-imitation method. For fairness, PKD is applied at the same positions as CrossKD, namely the neck output (\(i=0\)) and the head features (\(i=3\)). The results are as follows

It can be seen that no matter where PKD is applied, it performs worse than CrossKD.

CrossKD for Lightweight Detectors.

The authors' results of applying CrossKD to lightweight detectors are as follows

The teacher network is GFL with a ResNet-101 backbone, and the student networks use ResNet-50, ResNet-34, and ResNet-18; CrossKD significantly improves accuracy in all cases.

Comparison with SOTA KD Methods

The comparison with other SOTA distillation methods for object detection is shown in the table below. It can be seen that CrossKD outperforms all existing methods.

Code analysis

The official implementation is based on mmdetection, and CrossKD is applied to ATSS, FCOS, GFL, and RetinaNet. Taking ATSS as an example, the code is in mmdet/models/detectors/crosskd_atss.py, and the loss part is shown below. First, the original input batch_inputs is passed through the backbone and neck of the teacher and the student respectively: self.teacher.extract_feat is the teacher network's backbone and neck, and self.extract_feat is the student network's backbone and neck.

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        # Teacher: backbone + neck, then detection head, also holding the
        # intermediate head feature at the reused layer index.
        tea_x = self.teacher.extract_feat(batch_inputs)
        tea_cls_scores, tea_bbox_preds, tea_centernesses, tea_cls_hold, tea_reg_hold = \
            multi_apply(self.forward_hkd_single,
                        tea_x,
                        self.teacher.bbox_head.scales,
                        module=self.teacher)

        # Student: backbone + neck, then detection head, holding the same
        # intermediate feature for the cross-head forward pass.
        stu_x = self.extract_feat(batch_inputs)
        stu_cls_scores, stu_bbox_preds, stu_centernesses, stu_cls_hold, stu_reg_hold = \
            multi_apply(self.forward_hkd_single,
                        stu_x,
                        self.bbox_head.scales,
                        module=self)

        # Cross-head predictions: the student's held features are passed
        # through the remaining layers of the teacher's head.
        reused_cls_scores, reused_bbox_preds, reused_centernesses = multi_apply(
            self.reuse_teacher_head,
            tea_cls_hold,
            tea_reg_hold,
            stu_cls_hold,
            stu_reg_hold,
            self.teacher.bbox_head.scales)


        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs
        losses = self.loss_by_feat(tea_cls_scores, 
                                   tea_bbox_preds,
                                   tea_centernesses,
                                   tea_x,
                                   stu_cls_scores,
                                   stu_bbox_preds,
                                   stu_centernesses,
                                   stu_x,
                                   reused_cls_scores,
                                   reused_bbox_preds,
                                   reused_centernesses,
                                   batch_gt_instances,
                                   batch_img_metas, 
                                   batch_gt_instances_ignore)
        return losses

The neck output features tea_x and stu_x are then each passed through self.forward_hkd_single, as follows

    def forward_hkd_single(self, x, scale, module):
        cls_feat, reg_feat = x, x
        cls_feat_hold, reg_feat_hold = x, x
        for i, cls_conv in enumerate(module.bbox_head.cls_convs):
            cls_feat = cls_conv(cls_feat, activate=False)
            # hold the pre-activation feature at the reused layer index
            if i + 1 == self.reused_teacher_head_idx:
                cls_feat_hold = cls_feat
            cls_feat = cls_conv.activate(cls_feat)
        for i, reg_conv in enumerate(module.bbox_head.reg_convs):
            reg_feat = reg_conv(reg_feat, activate=False)
            if i + 1 == self.reused_teacher_head_idx:
                reg_feat_hold = reg_feat
            reg_feat = reg_conv.activate(reg_feat)
        # final predictions of this head: classification, bbox regression, centerness
        cls_score = module.bbox_head.atss_cls(cls_feat)
        bbox_pred = scale(module.bbox_head.atss_reg(reg_feat)).float()
        centerness = module.bbox_head.atss_centerness(reg_feat)
        return cls_score, bbox_pred, centerness, cls_feat_hold, reg_feat_hold

In forward_hkd_single, the features pass through the head's cls branch and reg branch for both the teacher and the student. self.reused_teacher_head_idx is the index of the layer from which the teacher's head is reused: the feature at this position is held (cls_feat_hold / reg_feat_hold) and later sent into the teacher's head by the function reuse_teacher_head.

    def reuse_teacher_head(self, tea_cls_feat, tea_reg_feat, stu_cls_feat,
                           stu_reg_feat, scale):
        # align the student features to the teacher features' per-channel statistics
        reused_cls_feat = self.align_scale(stu_cls_feat, tea_cls_feat)
        reused_reg_feat = self.align_scale(stu_reg_feat, tea_reg_feat)
        if self.reused_teacher_head_idx != 0:
            reused_cls_feat = F.relu(reused_cls_feat)
            reused_reg_feat = F.relu(reused_reg_feat)

        # run the remaining convs and prediction layers of the teacher's head
        module = self.teacher.bbox_head
        for i in range(self.reused_teacher_head_idx, module.stacked_convs):
            reused_cls_feat = module.cls_convs[i](reused_cls_feat)
            reused_reg_feat = module.reg_convs[i](reused_reg_feat)
        reused_cls_score = module.atss_cls(reused_cls_feat)
        reused_bbox_pred = scale(module.atss_reg(reused_reg_feat)).float()
        reused_centerness = module.atss_centerness(reused_reg_feat)
        return reused_cls_score, reused_bbox_pred, reused_centerness

Note the align_scale step here, which is not mentioned in the paper: the student head feature is normalized by subtracting its per-channel mean and dividing by its per-channel standard deviation, and is then multiplied by the standard deviation of the corresponding teacher head feature and shifted by the teacher's mean, as follows

    def align_scale(self, stu_feat, tea_feat):
        N, C, H, W = stu_feat.size()
        # normalize student feature
        stu_feat = stu_feat.permute(1, 0, 2, 3).reshape(C, -1)
        stu_mean = stu_feat.mean(dim=-1, keepdim=True)
        stu_std = stu_feat.std(dim=-1, keepdim=True)
        stu_feat = (stu_feat - stu_mean) / (stu_std + 1e-6)
        # rescale using the teacher feature's per-channel statistics
        tea_feat = tea_feat.permute(1, 0, 2, 3).reshape(C, -1)
        tea_mean = tea_feat.mean(dim=-1, keepdim=True)
        tea_std = tea_feat.std(dim=-1, keepdim=True)
        stu_feat = stu_feat * tea_std + tea_mean
        return stu_feat.reshape(C, N, H, W).permute(1, 0, 2, 3)
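
As a quick standalone sanity check of what align_scale does (this example is not from the repository), the aligned student feature ends up with approximately the teacher feature's per-channel mean and standard deviation:

    import torch

    N, C, H, W = 2, 8, 16, 16
    stu = torch.randn(N, C, H, W) * 3.0 + 1.0   # student feature, arbitrary statistics
    tea = torch.randn(N, C, H, W) * 0.5 - 2.0   # teacher feature, arbitrary statistics

    def channel_stats(x):
        # flatten to (C, N*H*W) and compute per-channel mean / std, as in align_scale
        x = x.permute(1, 0, 2, 3).reshape(C, -1)
        return x, x.mean(dim=-1, keepdim=True), x.std(dim=-1, keepdim=True)

    s, s_mean, s_std = channel_stats(stu)
    _, t_mean, t_std = channel_stats(tea)
    aligned = ((s - s_mean) / (s_std + 1e-6)) * t_std + t_mean
    aligned = aligned.reshape(C, N, H, W).permute(1, 0, 2, 3)

    # Per-channel statistics of the aligned feature now roughly match the teacher's.
    print(aligned.mean(dim=(0, 2, 3)))  # close to tea.mean(dim=(0, 2, 3))
    print(aligned.std(dim=(0, 2, 3)))   # close to tea.std(dim=(0, 2, 3))

In effect, the student feature is whitened per channel and then re-colored with the teacher's statistics, presumably so that the reused features fall in the value range the teacher's head layers were trained on.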

Original post: https://blog.csdn.net/ooooocj/article/details/131628652