YOLO V6系列(三) -- 损失函数的计算

YOLO V6系列(三) – 损失函数的计算

在上篇blogYOLO V6系列(二) – 网络结构解析里面大概介绍了美团视觉出的YOLO V6算法的网络结构,这篇主要解析下YOLO V6算法的损失函数的计算过程以及实现代码


首先是core/engine.pytrain方法调用了其中train_in_loop类方法,接着调用的是train_in_steps类方法,在该方法实现函数代码中total_loss, loss_items = self.compute_loss(preds, targets)就是损失函数的计算。先声明下,本文中,为了解释清楚,batch_size选择的是2。

preds = self.model(images)

在这里插入图片描述
上面这一行代码是通过YOLO V6的特征提取网络得到的预测值,通过一个列表来存储三个预测头所得到的预测值,不出意外的话,从上到下来说,每个tensor的shape应该是【2,1,80,80,6】,【2,1,40,40,6】,【2,1,20,20,6】,其中2表示的是batch_size,80,40,20表示的是通过PANet结构得到的不同维度,6表示的是(C + 4 + 1),C表示网络的类别数(这里只有一个类别,所以c=1),4表示位置信息,1表示预测框包含物体的概率大小。


然后,进入ComputeLoss类,直接调用__call__()方法,参数一共有两个:outputstargets,其中第一个是上述我们说的预测值,而后者就是相对应的图片的标签,这里值得注意的是targets已经是经过resize转换之后的标签大小了。
在这里插入图片描述
上图就是debug的时候第一轮第一个批次训练的targets,这里第一个维度表示是在一个批次batch_size中的index,换句话说就是属于哪张图的标签。第二个维度表示的是类别,这里算法就是单类别,所以都是0。后面四个维度表示就是经过resize转换之后所获取的人工标注框的大小。


创建相应的损失函数之后,调用outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(outputs, self.strides, dtype, device)类方法,从而调用decode_output函数。其中,

yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)

上述两行代码的意义就是将图像划分为单元网格。将output按照相对应的维度进行划分:bbox_preds、obj_preds、cls_preds。


loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls

其中,iou_loss是进行位置信息的损失函数计算,YOLO V6中使用的是siou损失函数。下面是实现代码。

            # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1 = torch.abs(s_cw) / sigma
            sin_alpha_2 = torch.abs(s_ch) / sigma
            threshold = pow(2, 0.5) / 2
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
            rho_x = (s_cw / cw) ** 2
            rho_y = (s_ch / ch) ** 2
            gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            iou = iou - 0.5 * (distance_cost + shape_cost)
        	loss = 1.0 - iou

猜你喜欢

转载自blog.csdn.net/weixin_42206075/article/details/125651920