Focus模块代码解析
这就是Focus结构想干的事情。把一个特征图:W * H * C 变到 W/2 * H/2 * C*4。取像素点的过程就是隔一个格子取一个点(横竖都间隔一个)。
在YOLOv5中,就是把一个640× 640 × 3的图片变为了320×320×12的图片了,很好理解吧,主要是这个过程的编程。
Yolov5中Focus的代码如下:
class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
# self.contract = Contract(gain=2)
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
# return self.conv(self.contract(x))
这个代码最重要的就是这一句。
self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
卷积就不说了,先说一下torch.cat
,torch.cat(input,dim)
,input为输入,dim为拼接的维度,这个函数就是将多个tensor拼接起来,在第一个维度,也就是通道数concat,也就是把下面四句函数表示的tensor拼接起来。
x[..., ::2, ::2]
x[..., 1::2, ::2]
x[..., ::2, 1::2]
x[..., 1::2, 1::2]
这里我举个例子,
a = torch.randn(4,4)
b = a[..., ::2, ::2]
a的结果如下:
tensor([[ 1.2897, -0.2437, -0.5905, 1.0346],
[ 0.2147, 0.7480, -0.4435, -1.1310],
[ 0.2540, 1.8344, -0.2806, 1.3506],
[-0.1076, -0.7443, 1.6323, 1.0701]])
b的结果如下:
tensor([[ 1.2897, -0.5905],
[ 0.2540, -0.2806]])
将4*4的tensor分为4块,这句代码的意思就是取到每一块的左上角的值,达到focus的目的。剩下三句代码也就可以想象到了,最后cat在一起得到Focus之后的特征图。