数值类型张量与布尔类型张量的乘积

最近在学习yolov5源码中关于损失函数的计算,build_targets函数中,有一段代码是关于去除与指定anchor宽高比过大的gt box的,代码如下:

if nt:
   # Matches
   r = t[..., 4:6] / anchors[:, None]  # wh ratio
   j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare 
                
   t = t[j]  # filter 过滤 

其中 t=t[j] 这句实在没懂,但由代码可知,j 是由True和False组成的张量,于是就自己写了几行代码验证看下结果。(实际上debug观察t的shape可以明白这行代码的作用)

import torch

j = torch.tensor([True, True, False, False])
t = torch.tensor([1, 2, 3, 4])

t = t[j]
print(t)

运行结果如下:

 

由此可见,j 起到了一个过滤的作用,用于 t 的有效值筛选。 

猜你喜欢

转载自blog.csdn.net/weixin_48747603/article/details/127335867