Focus模块代码解析

Focus模块代码解析

在这里插入图片描述
这就是Focus结构想干的事情。把一个特征图:W * H * C 变到 W/2 * H/2 * C*4。取像素点的过程就是隔一个格子取一个点(横竖都间隔一个)。
在YOLOv5中,就是把一个640× 640 × 3的图片变为了320×320×12的图片了,很好理解吧,主要是这个过程的编程。
Yolov5中Focus的代码如下:

class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))

这个代码最重要的就是这一句。
self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
卷积就不说了,先说一下torch.cattorch.cat(input,dim),input为输入,dim为拼接的维度,这个函数就是将多个tensor拼接起来,在第一个维度,也就是通道数concat,也就是把下面四句函数表示的tensor拼接起来。

x[..., ::2, ::2]
x[..., 1::2, ::2]
x[..., ::2, 1::2]
x[..., 1::2, 1::2]

这里我举个例子,

a = torch.randn(4,4)
b = a[..., ::2, ::2]  
a的结果如下:
tensor([[ 1.2897, -0.2437, -0.5905,  1.0346],
        [ 0.2147,  0.7480, -0.4435, -1.1310],
        [ 0.2540,  1.8344, -0.2806,  1.3506],
        [-0.1076, -0.7443,  1.6323,  1.0701]])
b的结果如下:
tensor([[ 1.2897, -0.5905],
        [ 0.2540, -0.2806]])

将4*4的tensor分为4块,这句代码的意思就是取到每一块的左上角的值,达到focus的目的。剩下三句代码也就可以想象到了,最后cat在一起得到Focus之后的特征图。

猜你喜欢

转载自blog.csdn.net/JiatongForever/article/details/127409458
今日推荐