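1. Obtaining ground-truth labels for computing the loss
The ground-truth labels that come from the dataset describe whole samples, whereas during training the network predicts (x,y,w,h,c) for every pixel (grid cell) of every feature map. A matching ground-truth label therefore has to be built for every cell of every feature map.
First, initialize the ground-truth label arrays.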
nB = target.size(0) # batch size
nA = num_anchors # number of anchors
nC = num_classes # number of classes
nG = grid_size # size of the grid feature map
mask = torch.zeros(nB, nA, nG, nG) # (batch_size,3,13/26/52,13/26/52)
conf_mask = torch.ones(nB, nA, nG, nG)
tx = torch.zeros(nB, nA, nG, nG)
ty = torch.zeros(nB, nA, nG, nG)
tw = torch.zeros(nB, nA, nG, nG)
th = torch.zeros(nB, nA, nG, nG)
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0)
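Build the labels for each grid cell.
# target stores relative (normalized) coordinates, so multiply by the feature-map size to recover grid units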
gx = target[b, t, 1] * nG
gy = target[b, t, 2] * nG
gw = target[b, t, 3] * nG
gh = target[b, t, 4] * nG
gi = int(gx) # grid cell indices
gj = int(gy)
gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)
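The scaling step above can be checked with a small standalone sketch. It uses plain Python instead of tensors, and the function name to_grid and the sample values are illustrative assumptions, not part of the original code.

```python
def to_grid(cx, cy, w, h, grid_size):
    """Scale a normalized box (values in [0, 1]) to grid units and locate its cell."""
    gx, gy = cx * grid_size, cy * grid_size   # box center in grid units
    gw, gh = w * grid_size, h * grid_size     # box size in grid units
    gi, gj = int(gx), int(gy)                 # column and row of the cell containing the center
    return gx, gy, gw, gh, gi, gj


# Hypothetical box centered at (0.5, 0.25) on a 13x13 feature map
gx, gy, gw, gh, gi, gj = to_grid(0.5, 0.25, 0.2, 0.4, 13)
print(gx, gy, gi, gj)
```

Note that gi indexes the column (x direction) and gj the row (y direction), which is why the label arrays are later indexed as [b, best_n, gj, gi].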
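Compute the IoU between the ground-truth box and the anchor shapes. It is used to decide which of the 3 prior positions in the (batch_size,3,13,13,85) prediction is the best match.
The IoU computation is described in detail in an earlier blog post: https://blog.csdn.net/a362682954/article/details/82896242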
anchor_shapes = torch.FloatTensor(np.concatenate(
(np.zeros((len(anchors), 2)), np.array(anchors)), 1))
# Calculate iou between gt and anchor shapes
anch_ious = bbox_iou(gt_box, anchor_shapes)
# Where the overlap is larger than threshold set mask to zero (ignore)
conf_mask[b, anch_ious > ignore_thres, gj, gi] = 0
# Find the best matching anchor box
best_n = np.argmax(anch_ious)
# Get ground truth box
gt_box = torch.FloatTensor(np.array([gx, gy, gw, gh])).unsqueeze(0)
# Get the best prediction
pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0)
# Masks, used to mark the predicted box with the highest overlap
mask[b, best_n, gj, gi] = 1
conf_mask[b, best_n, gj, gi] = 1
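The anchor matching above compares shapes only: both the ground-truth box and the anchors are placed at the same origin, so only width and height affect the IoU. The following is a minimal sketch of that idea in plain Python. The helper names shape_iou and match_anchors and the anchor sizes are assumptions for illustration, and the sketch does not reproduce any pixel-offset convention the post's bbox_iou may apply.

```python
def shape_iou(w1, h1, w2, h2):
    """IoU of two boxes that share the same corner, so only their sizes matter."""
    inter = min(w1, w2) * min(h1, h2)
    union = w1 * h1 + w2 * h2 - inter
    return inter / (union + 1e-16)


def match_anchors(gw, gh, anchors, ignore_thres=0.5):
    """Return the best-matching anchor index and the other anchors to ignore."""
    ious = [shape_iou(gw, gh, aw, ah) for aw, ah in anchors]
    best_n = max(range(len(anchors)), key=lambda i: ious[i])
    # Anchors above the threshold that are not the best match are excluded
    # from the confidence loss (conf_mask = 0 in the code above).
    ignored = [i for i, v in enumerate(ious) if v > ignore_thres and i != best_n]
    return best_n, ignored, ious


# Hypothetical anchors, already expressed in grid units
anchors = [(1.0, 1.5), (2.0, 4.0), (5.0, 4.0)]
best_n, ignored, ious = match_anchors(2.6, 5.2, anchors, ignore_thres=0.4)
print(best_n, ignored)
```

In the original code the same effect is reached in two steps: conf_mask is first zeroed for every anchor above ignore_thres, then set back to 1 for best_n.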
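Then compute the coordinates relative to the grid cell, the class, and the confidence.
# Ground-truth coordinates relative to the grid cell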
tx[b, best_n, gj, gi] = gx - gi
ty[b, best_n, gj, gi] = gy - gj
# Width and height
tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n][0] + 1e-16)
th[b, best_n, gj, gi] = math.log(gh / anchors[best_n][1] + 1e-16)
# One-hot encoding of label
target_label = int(target[b, t, 0])
tcls[b, best_n, gj, gi, target_label] = 1
tconf[b, best_n, gj, gi] = 1
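To see that the targets tx, ty, tw, th written above carry enough information to recover the original box, here is a small round-trip sketch in plain Python. The encode and decode helpers and the sample numbers are illustrative assumptions. The decode step simply inverts the formulas above and is not taken from the original post.

```python
import math


def encode(gx, gy, gw, gh, anchor_w, anchor_h):
    """Turn a box in grid units into regression targets for its cell and anchor."""
    gi, gj = int(gx), int(gy)
    tx, ty = gx - gi, gy - gj                  # offset inside the cell, in [0, 1)
    tw = math.log(gw / anchor_w + 1e-16)       # log-scale factor relative to the anchor
    th = math.log(gh / anchor_h + 1e-16)
    return gi, gj, tx, ty, tw, th


def decode(gi, gj, tx, ty, tw, th, anchor_w, anchor_h):
    """Invert encode(): recover the box in grid units."""
    return gi + tx, gj + ty, anchor_w * math.exp(tw), anchor_h * math.exp(th)


# Hypothetical box and anchor, both in grid units
gi, gj, tx, ty, tw, th = encode(6.5, 3.25, 2.6, 5.2, 2.0, 4.0)
box = decode(gi, gj, tx, ty, tw, th, 2.0, 4.0)
print(tx, ty, box)
```

Taking the logarithm of the size ratio means a box exactly the size of its anchor gets tw = th = 0, and larger or smaller boxes are treated symmetrically.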