torch.chunk与nn.Conv2d groups

torch.chunk

切分
假如特征x大小为:32x64x224x224 (BxCxHxW)
q = torch.chunk(x, 8, dim=1)
x是要切分的特征,8是要切分成几块,dim是指定切分的维度,这里等于1,就是按通道切分
就会将其按照通道,切分为8块,那么每一块就是32x8x224x224
返回的q是一个元组,将这八块放在元组里面

nn.Conv2d groups

图片来自:https://www.bilibili.com/video/BV1SL411V7dc?p=2&vd_source=20756f1667908eb0bfec8057bec3fb85
groups默认值是1,就是分为一组,相当于没分组
比如:输入特征大小为1x6x5x5,卷积核大小为9x6x4x4,则输出特征为1x9x2x2,此时groups=1
在这里插入图片描述
如果将groups设置为3
那么会将输出通道和输入通道都分为三组
输入特征就是三组1x2x5x5,输出特征就是三组1x3x2x2,相互对应。每一组都有一个3x2x4x4的卷积核
这三个卷积核拼起来就是9x2x4x4(可以看到,参数量减少了)
在这里插入图片描述
设置groups时要注意,它必须是in_channels和out_channels的公约数。

猜你喜欢

转载自blog.csdn.net/holly_Z_P_F/article/details/128370943