目标检测知识蒸馏---以SSD为例【附代码】

在上一篇文章中有讲解以分类网络为例的知识蒸馏【分类网络知识蒸馏】,这篇文章将会针对目标检测网络进行蒸馏。

知识蒸馏是一种不改变网络结构模型压缩方法。这里的压缩需要和量化与剪枝进行区分,并不是严格意义上的压缩。这里将要讲的蒸馏是离线式蒸馏中的逻辑蒸馏【特征部分蒸馏以后会讲】,也是一种常用的方法,他是将已经训练好的teacher model对student model进行蒸馏。

teacher model是一个在精度表现上优良的模型,而student model往往是精度差一些,但推理速度高的模型。如果要采用这种蒸馏方式,需要注意的是两个Model的网络结构需要相似【因此可以将改进前后的model建立这种关系】。而实现部分最最重要的部分是建立蒸馏的Loss函数。


目录

MultiBoxloss

loss参数说明

 loss forward部分

标签匹配函数match

loss计算 

 MultiBoxloss_KD

kd for loc regression

kd for conf regression


在目标检测中主要有两个任务,一个是分类,一个是边界的回归,前者的蒸馏是比较容易的,关键是在后者,这也是蒸馏的一个难点。

我们先来看一下SSD代码中的MultiBoxloss部分详解。

MultiBoxloss

SSD中分类loss采用CrossEntropy,边界loss采用平滑L1。具体公式和网络算法原理参考论文,这里不在多说。

loss参数说明

参数说明:

self.use_gpu:是否采用gpu训练

self.num_classes:训练类的数量【在SSD中num_classes是自己的类数量+背景类】

self.threshold:阈值,默认0.5

self.background_label:背景类标签,默认为0

self.encode_target:target编码

self.use_prior_for_matching:利用先眼眶做匹配

self.do_neg_mining:True,负样本挖掘

self.negpos_ratio:负样本比例,设置为3【正负样本比例为1:3】

self.variance :方差

class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
        Compute Targets:
            1) Produce Confidence Target Indices by matching  ground truth boxes
               with (default) 'priorboxes' that have jaccard index > threshold parameter
               (default threshold: 0.5).
            2) Produce localization target by 'encoding' variance into offsets of ground
               truth boxes and their matched  'priorboxes'.
            3) Hard negative mining to filter the excessive number of negative examples
               that comes with using a large number of default bounding boxes.
               (default negative:positive ratio 3:1)
        Objective Loss:
            L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
            Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
            weighted by α which is set to 1 by cross val.
            Args:
                c: class confidences,
                l: predicted boxes,
                g: ground truth boxes
                N: number of matched default boxes
            See: https://arxiv.org/pdf/1512.02325.pdf for more details.
        """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.negatives_for_hard = negatives_for_hard
        self.variance = Config['variance']

 loss forward部分

predictions:类型为tuple,网络的输出内容,包含:位置预测,分类置信度预测以及prior boxes预测。

predictions[0]的shape为:[batch,8732,4]

predictions[1]的shape为:[batch,8732,num_classes]

predictions[2]的shape为:[8732,4]

注:8732:以输入大小300*300为例,将在6个head部分产生8732个先眼眶.

8732= 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*6 + 1*1*4

target:包含了标注的数据集真实的boxes坐标以及label信息。是一个列表,列表的长度等于batch的数量,每个列表中的元素shape为[num_objs,5],num_objs表示你当前图像中标注的目标数量,5=boxes信息+label信息。

    def forward(self, predictions, targets):
        """Multibox Loss
                Args:
                    predictions (tuple): A tuple containing loc preds, conf preds,
                    and prior boxes from SSD net.
                        conf shape: torch.size(batch_size,num_priors,num_classes)
                        loc shape: torch.size(batch_size,num_priors,4)
                        priors shape: torch.size(num_priors,4)
                    pred_t (tuple): teacher's predictions
                    targets (tensor): Ground truth boxes and labels for a batch,
                        shape: [batch_size,num_objs,5] (last idx is the label).
                """

        #--------------------------------------------------#
        #   取出预测结果的三个值:回归信息,置信度,先验框
        #--------------------------------------------------#
        loc_data, conf_data, priors = predictions
        

创建两个全零张量用来做先验框和真实框的匹配,这里的num等于batch_size

loc_t = torch.zeros(num, num_priors, 4).type(torch.FloatTensor)
conf_t = torch.zeros(num, num_priors).long()

遍历每个batch,idx是batch的索引,truths是获取到的真实值的boxes信息。labels是获取到的当前图像中是什么类。

truths:tensor([[0.2333, 0.2067, 0.6967, 1.0000]], device='cuda:0')

labels:tensor([0.], device='cuda:0')

        for idx in range(num):
            # 获得真实框与标签
            truths = targets[idx][:, :-1]
            labels = targets[idx][:, -1]

            if(len(truths)==0):
                continue

            # 获得先验框
            defaults = priors
            #--------------------------------------------------#
            #   利用真实框和先验框进行匹配。
            #   如果真实框和先验框的重合度较高,则认为匹配上了。
            #   该先验框用于负责检测出该真实框。
            #--------------------------------------------------#
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)

defaults是先验框,接下来是和真实框进行标签匹配。

标签匹配函数match

        传入参数:threshold,truths[真实boxes],defaults[先验框],variance[方差],labels[真实标签],loc_t[前面创建的全零张量,用来存放匹配后的boxes信息], conf_t[用来存储匹配后的置信度分类信息],idx[当前batch的索引 ]。

1.先计算所有先验框和真实框的的重合程度。

        box_a是就是上面的truths,box_b是先验框【注意先验框中的boxes形式是center_x,center_y,w,h,需要先转成左上角和右下角的形式】。最终就可以计算出IOU。

def jaccard(box_a, box_b):
    #-------------------------------------#
    #   返回的inter的shape为[A,B]
    #   代表每一个真实框和先验框的交矩形
    #-------------------------------------#
    inter = intersect(box_a, box_b)
    #-------------------------------------#
    #   计算先验框和真实框各自的面积
    #-------------------------------------#
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]

    union = area_a + area_b - inter
    #-------------------------------------#
    #   每一个真实框和先验框的交并比[A,B]
    #-------------------------------------#
    return inter / union

因此得到的overlaps是计算的所有先验框和真实框的iou,shape为[1,8732]。 

overlaps = jaccard(
        truths,
        point_form(priors)
    )

接下来是通过max函数获得这8732个先验框中与真实框匹配度最好的框和索引【就相当于可以把这个匹配的最好的认为是ground truth】。

可以得到iou最高的是0.6904,是第8711号先验框。

best_prior_overlap:tensor([[0.6904]], device='cuda:0')

best_prior_idx:tensor([[8711]], device='cuda:0') 

用于保证每个真实框都有一个先验框与之匹配。

    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)

 将truths扩充成8732.

matches = truths[best_truth_idx]

获取标签 

conf = labels[best_truth_idx] + 1

获取背景类,通过设置的iou阈值进行过滤。

conf[best_truth_overlap < threshold] = 0

进行边界框的编码【其实就是将真实框和先验框进行匹配】。

def encode(matched, priors, variances):
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    g_cxcy /= (variances[0] * priors[:, 2:])
    
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    return torch.cat([g_cxcy, g_wh], 1)

获得的loc shape为【8732,4】 

loc = encode(matches, priors, variances)

将编码后的loc放入前面定义loc_t中,conf也是如此。

获得正样本。

# 所有conf_t>0的地方,代表内部包含物体
pos = conf_t > 0

此时的pos形式如下,shape为【batch,8732】: 

tensor([[False, False, False,  ..., False, False,  True],
        [False, False, False,  ..., False, False, False]], device='cuda:0')

求和得到每个图像内有多少正样本。这就可以计算出在所有的batch中的所有batch*8732个框中有多少框内包含目标。

num_pos = pos.sum(dim=1, keepdim=True)

loss计算 

取出所有的正样本计算loss

获得所有正样本的idx,返回形式是Truth or False.

pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

通过索引在loc_data[预测的位置]中选择出正样本的loc_p【也就是预测目标的loc】。 

loc_p = loc_data[pos_idx].view(-1, 4)

通过正样本的索引在loc_t【groud truth】中进行筛选获得正样本的loc_t。

loc_t = loc_t[pos_idx].view(-1, 4)

计算边界回归loss:

直接调用smooth_l1_loss计算loss【loc_p是预测值,loc_t是真实值】

 loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

分类loss:

获得网络预测的conf,进行reshape,这就获得了所有batch中预测框内的conf,shape为【batch*8732,num_classes】。

batch_conf = conf_data.view(-1, self.num_classes)

 conf_p是预测值【筛选后具有正样本的】,

# 这个地方是在寻找难分类的先验框
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
        loss_c = loss_c.view(num, -1)

        # 难分类的先验框不把正样本考虑进去,只考虑难分类的负样本
        loss_c[pos] = 0 
        #--------------------------------------------------#
        #   loss_idx    (num, num_priors)
        #   idx_rank    (num, num_priors)
        #--------------------------------------------------#
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        #--------------------------------------------------#
        #   求和得到每一个图片内部有多少正样本
        #   num_pos     (num, )
        #   neg         (num, num_priors)
        #--------------------------------------------------#
        num_pos = pos.long().sum(1, keepdim=True)
        # 限制负样本数量
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max = pos.size(1) - 1)
        num_neg[num_neg.eq(0)] =  self.negatives_for_hard
        neg = idx_rank < num_neg.expand_as(idx_rank)

        #--------------------------------------------------#
        #   求和得到每一个图片内部有多少正样本
        #   pos_idx   (num, num_priors, num_classes)
        #   neg_idx   (num, num_priors, num_classes)
        #--------------------------------------------------#
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)

        # 选取出用于训练的正样本与负样本,计算loss
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

最后总的Loss为:

loss:8.0996


 MultiBoxloss_KD

在原来的loss基础上加入了soft-target loss部分。

class MultiBoxLoss_KD(nn.Module):

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0,neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss_KD, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes  # 21
        self.threshold = overlap_thresh  # 0.5
        self.background_label = bkg_label  # 0
        self.encode_target = encode_target  # False
        self.use_prior_for_matching = prior_for_matching  # True
        self.do_neg_mining = neg_mining  # True
        self.negpos_ratio = neg_pos  # 3
        self.neg_overlap = neg_overlap  # 0.5
        self.variance = Config['variance']
        self.negatives_for_hard = negatives_for_hard

        # soft-target loss
        self.neg_w = neg_w # 负样本(背景)权重
        self.pos_w = pos_w # 正样本权重
        self.Temp = Temp # 温度
        self.reg_m = reg_m

在forward部分传入参数为predictions[student的输出,pred_t为teacher的输出,targets是真实值] 

def forward(self, predictions, pred_t, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)
            pred_t (tuple): teacher's predictions
            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """

kd for loc regression

 这里的loc regression采用的是l2 loss.

        # teach1  这里的s指student,t指真实值
        loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)  # loc_teach1_p = tensor<(3, 4), float32, cuda:0, grad>
        l2_dis_s = (loc_p - loc_t).pow(2).sum(1)  # Σ(loc_p-loc_t)² 计算学生L2 loss,(学生预测loc-真实标签)²  sum(1)求行和  l2_dis_s = tensor<(3,), float32, cuda:0, grad>
        l2_dis_s_m = l2_dis_s + self.reg_m  # l2_dis_s_m = tensor<(3,), float32, cuda:0, grad>
        l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)  # L2 loss:(老师loc预测值-真实标签)²并求和  l2_dis_t = tensor<(3,), float32, cuda:0, grad>
        l2_num = l2_dis_s_m > l2_dis_t  # 判断学生位置回归与真实reg距离 和 老师位置回归与真实标签距离 的大小  l2_num = tensor<(3,), bool, cuda:0>
        l2_loss_teach1 = l2_dis_s[l2_num].sum()  # 当学生大于老师 Lb(Rs,Rt,y)=Σ(loc_p-loc_t)²,否则为0 Lb表示文章定义的teacher bounded regression loss
                                                 # l2_loss_teach1 = tensor<(), float32, cuda:0, grad> 取出l2_num=True的
        l2_loss = l2_loss_teach1  # l2_loss = tensor<(), float32, cuda:0, grad>

kd for conf regression

conf_p是预测值,ps是student的分类预测,pt是teacher的分类预测,计算两者loss。

        # soft loss for Knowledge Distillation
        # teach1
        conf_p_teach = conf_teach1[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        pt = F.softmax(conf_p_teach / self.Temp, dim=1)
        if self.neg_w > 1.:
            ps = F.softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
        else:
            ps = F.log_softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = nn.KLDivLoss(size_average=False)(ps, pt) * (self.Temp ** 2)
        soft_loss = soft_loss1

最后返回有4个loss,soft_loss是分类kd loss,l2loss是loc 蒸馏loss,loss_c, loss_l均为hard loss中的student自己的loss。 

        loss_l /= N
        loss_c /= N
        l2_loss /= N
        soft_loss /= N
        return soft_loss, l2_loss, loss_c, loss_l

 然后可以根据自己的情况给不同的loss分配不同的权重进行训练。

            soft_loss, l2_loss, loss_c, loss_l = criterion(out_student, teacher_out, targets)  # KD损失函数
            # loss_l, loss_c = criterion1(out_student, targets) # criterion1原损失函数

            loss = (0.3 * soft_loss + 0.7 * loss_c) + (0.5 * l2_loss + loss_l)

训练如下(这里为了方便演示我这里只放了100张图片训练): 

Epoch 9/10: 100%|██████████| 54/54 [00:19<00:00,  2.80it/s, conf_loss=2.37, loc_loss=0.752, lr=7.16e-5]
Start Teacher Validation
Epoch 9/10: 100%|██████████| 6/6 [00:02<00:00,  2.01it/s, conf_loss=2.24, loc_loss=0.669, lr=7.16e-5]
Finish Teacher Validation
Epoch:9/10
Total Loss: 3.0658 || Val Loss: 2.4932 
Saving state, iter: 9
Start Teacher Train
Epoch 10/10: 100%|██████████| 54/54 [00:19<00:00,  2.79it/s, conf_loss=2.28, loc_loss=0.73, lr=6.59e-5]
Epoch 10/10:   0%|          | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher Validation
Epoch 10/10: 100%|██████████| 6/6 [00:03<00:00,  1.97it/s, conf_loss=2.14, loc_loss=0.691, lr=6.59e-5]
Finish Teacher Validation
Epoch:10/10
Total Loss: 2.9564 || Val Loss: 2.4282 
Saving state, iter: 10
开始蒸馏训练
Loading weights into state dict...
Finished!
Epoch 1/2:   0%|          | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 1/2: 100%|██████████| 54/54 [00:20<00:00,  2.60it/s, conf_loss=3.19, l2_loss=8.29, loc_loss=2.91, lr=0.0005, soft_loss=3.9]
Start Teacher2student_KD Validation
Epoch 1/2: 100%|██████████| 6/6 [00:02<00:00,  2.12it/s, conf_loss=2.75, loc_loss=2.56, lr=0.0005]
Finish teacher2student_KD Validation
Epoch:1/2
Total Loss: 17.9524 || Val Loss: 4.5457 
Saving state, iter: 1
Start teacher2student_KD Train
Epoch 2/2: 100%|██████████| 54/54 [00:20<00:00,  2.58it/s, conf_loss=2.99, l2_loss=6.52, loc_loss=2.53, lr=0.00046, soft_loss=3.41]
Start Teacher2student_KD Validation
Epoch 2/2: 100%|██████████| 6/6 [00:02<00:00,  2.14it/s, conf_loss=2.65, loc_loss=2.68, lr=0.00046]
Finish teacher2student_KD Validation
Epoch:2/2
Total Loss: 15.1709 || Val Loss: 4.5685 
Saving state, iter: 2
Epoch 3/4:   0%|          | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 3/4: 100%|██████████| 54/54 [00:25<00:00,  2.12it/s, conf_loss=2.66, l2_loss=6.92, loc_loss=2.63, lr=0.0001, soft_loss=3.05]
Epoch 3/4:   0%|          | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher2student_KD Validation
Epoch 3/4: 100%|██████████| 6/6 [00:02<00:00,  2.21it/s, conf_loss=2.39, loc_loss=2.66, lr=0.0001]
Finish teacher2student_KD Validation
Epoch:3/4
Total Loss: 14.9715 || Val Loss: 4.3286 
Saving state, iter: 3
Start teacher2student_KD Train
Epoch 4/4: 100%|██████████| 54/54 [00:25<00:00,  2.12it/s, conf_loss=2.44, l2_loss=6.46, loc_loss=2.54, lr=9.2e-5, soft_loss=2.84]
Start Teacher2student_KD Validation
Epoch 4/4: 100%|██████████| 6/6 [00:02<00:00,  2.15it/s, conf_loss=2.38, loc_loss=2.57, lr=9.2e-5]
Finish teacher2student_KD Validation
Epoch:4/4
Total Loss: 14.0245 || Val Loss: 4.2435 
Saving state, iter: 4

注:离线蒸馏训练对于teacher model也是有要求的,我这里的teacher model只是随便在原model的基础上改了一下训练而已,我这里仅仅是演示一下,具体的改进等需要自己去不断尝试。因此kd的好坏是取决于两个模型的。

大家也可以尝试其他的蒸馏方式,有问题可评论留言~~欢迎支持

猜你喜欢

转载自blog.csdn.net/z240626191s/article/details/128759731
今日推荐