pyotrch nn.Conv2d中groups参数的理解


在pytorch的Docs中有关于nn.Conv2d的具体描述:

 torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
 

https://zhuanlan.zhihu.com/p/35405071

import torch
import torch.nn as nn
from torch.autograd import Variable

input = torch.ones(1, 3, 224, 224)
input = Variable(input)
f = nn.Conv2d(in_channels=3, out_channels=9, kernel_size=5, groups=3)
output = f(input)
print(output.shape) # (1, 9, 220, 220)


我们通过实际的例子加以说明:

# pytorch 0.3.0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May  2 20:13:05 2018

@author: huijian
"""

# experiment about groups
import torch
import torch.nn as nn
from torch.autograd import Variable

x = torch.FloatTensor([0.1,1,10,100,1000,10000]).view(1,-1,1,1)
x = Variable(x)

conv = nn.Conv2d(in_channels=6,
                 out_channels=6,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=3,
                 bias=False)

print(conv.weight.data.size())
# [1,2,3,4,5,6,7,8,9,10,11,12]
conv.weight.data = torch.arange(1,13).view(6,2,1,1)

print(conv.weight.data)

output=conv(x)
print(output)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
运行出来的结果非常不便于观察,如下

# print(conv.weight.data.size())
torch.Size([6, 2, 1, 1])
#print(conv.weight.data)
tensor([[[[  1.]],
         [[  2.]]],
        [[[  3.]],
         [[  4.]]],
        [[[  5.]],
         [[  6.]]],
        [[[  7.]],
         [[  8.]]],
        [[[  9.]],
         [[ 10.]]],
        [[[ 11.]],
         [[ 12.]]]])

#print(output)
tensor([[[[ 2.1000e+00]],
         [[ 4.3000e+00]],
         [[ 6.5000e+02]],
         [[ 8.7000e+02]],
         [[ 1.0900e+05]],
         [[ 1.3100e+05]]]]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
如果对数字比较敏感的话,大概可以看出来其实是通过groups 这个参数对输入的channel 划分成多组,像在此处,原始的输入有6个channel 被划分成了2个一组。 
我们可以看到此处卷积层的参数是[6,2,1,1] ,如果将卷积层的权重用另一种方式展示会比较清楚:

#weight[0,:,:,:]
array([[[1.]],
       [[2.]]], dtype=float32)
#weight[1,:,:,:]
array([[[3.]],
       [[4.]]], dtype=float32)
# ....
1
2
3
4
5
6
7
我们可以把思路整理如下: 
我们有6个channel的输入,我们需要6个channel的输出,如果考虑默认的卷积层,其大小应该是torch.Size([6, 6, 1, 1])也就是说每个输出的channel的计算,所有输入的channel都参与了。 
此处我们通过值可以发现,这里每个输出channel的计算每次只有2个输入的channel参与了。 
所以其结果分别是 
2.1(1x0.1+2x1)(weight[0,:,:,:]) 4.3 (3x0.1+4x1) 650 (5x10+6x100) ...

如果我们对代码进行修改,将其中的groups由3改成2,但是由于计算问题,我们将output 修改为 output/100

# output/100
tensor([[[[    0.3210]], # 30 + 2 + 0.1
         [[    0.6540]], # 60 + 5 + 0.4
         [[    0.9870]], # 90 + 8 + 0.7
         [[ 1320.0000]], # 120000 +11000 + 1000
         [[ 1653.0000]], # 150000 +14000 + 1300
         [[ 1986.0000]]]]) # 180000 +17000 +1600
1
2
3
4
5
6
7
这里我们可以看到原有的[0.1,1,10,100,1000,10000] 就被划分成为了两组 [0.1,1,10] , [100,1000,10000]。每一组被重用3次。

这之前的情况都是in_channels = out_channels 因此我们测试当in_channels != out_channels 
代码如下:

import torch
import torch.nn as nn
from torch.autograd import Variable

x = torch.FloatTensor([0.1,1,10,100,1000,10000]).view(1,-1,1,1)
x = Variable(x)

conv = nn.Conv2d(in_channels=6,
                 out_channels=18,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=2,
                 bias=False)

print(conv.weight.data.size())
# [1,2,3,4,5,6,7,8,9,10,11,12]
conv.weight.data = torch.arange(1,55).view(18,3,1,1)

print(conv.weight.data)

output=conv(x)
print(output)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
我们观察结果:

# output/100
tensor([[[[    0.3210]], # 30 +2 +0.1
         [[    0.6540]],
         [[    0.9870]],
         [[    1.3200]],
         [[    1.6530]],
         [[    1.9860]],
         [[    2.3190]],
         [[    2.6520]],
         [[    2.9850]],
         [[ 3318.0000]],
         [[ 3651.0000]],
         [[ 3984.0000]],
         [[ 4317.0000]],
         [[ 4650.0000]],
         [[ 4983.0000]],
         [[ 5316.0000]],
         [[ 5649.0000]],
         [[ 5982.0000]]]])

因此我们可以理解为: 
groups 决定了将原输入分为几组,而每组channel重用几次,由out_channels/groups计算得到,这也说明了为什么需要 groups能供被 out_channels与in_channels整除。
————————————————
版权声明:本文为CSDN博主「TianJiu_23333」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/monsterhoho/article/details/80173400

发布了2614 篇原创文章 · 获赞 926 · 访问量 509万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/104008918
今日推荐