SKNet论文

论文连接:Selective Kernel Networks
code:pytorch代码

SENet主要受神经学的启发,视觉皮层中神经元可以根据不同的刺激动态的调整自身的RF(receptive field,感受野)。其实在之前的很多网络中都采用了多个不同大小kernel,然后将他们提取到的特征进行融合,比如Inception采用了 3 × 3 3 \times 3 , 5 × 5 5 \times 5 , 7 × 7 7 \times 7 的kernel。
论文提出的一种使用非线性的方法来整合多个kernel提取的信息。作者提出了“Selective Kernel”(SK)卷积,其包含Split,FuseSelect操作。
图1
上图采用了2个大小不同的卷积核,在实际应用中可以轻松扩展到多个卷积核。

Split

如上图所示,采用两个不同的卷积核分别对 X X 进行卷积操作,得到 U ^ \hat{U} U ~ \tilde{U} ,两者的大小( H × W × C H \times W \times C )相同.

Fuse

1. U ^ \hat{U} U ~ \tilde{U} 对应元素相加,得到 U ( H × W × C ) U(H \times W \times C)
U = U ^ + U ~ U = \hat{U}+\tilde{U}
2.对 U U 进行global average pooling,得到 S S S S U U 的通道数一样,均为 C C

3.使用全连接层对 S S 进行压缩,得到 Z Z Z Z 的通道数为 d d .
z = F f c ( s ) = δ ( B ( W s ) ) z = F_{fc}(s)=\delta(B(\bm{W}s))
δ \delta 为ReLU函数, B B 为Batch Normalization。
d = m a x ( C / r , L ) d = max(C/r,L)
r r 为缩放比例, L L 为通道的最小值,论文中设置为32.

Select

a c = e A c z e A c z + e B c z , b c = e B c z e A c z + e B c z a_{c}=\frac{e^{\mathbf{A}_{c} \mathbf{z}}}{e^{\mathbf{A}_{c} \mathbf{z}}+e^{\mathbf{B}_{c} \mathbf{z}}}, b_{c}=\frac{e^{\mathbf{B}_{c} \mathbf{z}}}{e^{\mathbf{A}_{c} \mathbf{z}}+e^{\mathbf{B}_{c} \mathbf{z}}}

其中 A , B R C × d \mathbf{A}, \mathbf{B} \in \mathbb{R}^{C \times d} a , b \mathbf{a},\mathbf{b} 分别代表 U ^ \hat{U} U ~ \tilde{U} 的soft attention vector。
在这里 a c + b c = 1 a_c+b_c=1
V c = a c U c ^ + b c U c ~ V_c = a_c \cdot \hat{U_c} + b_c \cdot{\tilde{U_c}}
V = [ V 1 , V 2 , , V C ] \mathbf{V} = [\mathbf{V_1},\mathbf{V_2},\dots,\mathbf{V_C} ]

网络结构

SKnet网络结构
SKNet由多个SK Unit组合而成,每个SK Unit包含 1 × 1 1 \times 1 卷积,SK 卷积, 1 × 1 1 \times 1 卷积。

实验

具体实验部分请参考论文,作者通过多个实验证明了SKNet的效果。

代码

import torch
from torch import nn

class SKConv(nn.Module):
	def __init__(self, features, M, G, r, stride=1, L=32):
		"""
		:param features: input channel dimensionality
		:param WH: input spatial dimensionality, used for GAP kernel size.
		:param M:the number of branchs.
		:param G:number of convolution group
		:param r:the radio for compute d, the length of z
		:param stride: stride, default 1.
		:param L:the minimum dim of vector z in paper, default 32.
		"""
		super(SKConv, self).__init__()
		d = max(int(features/r), L)
		self.M = M
		self.features = features
		self.convs = nn.ModuleList([])
		for i in range(self.M):
			self.convs.append(nn.Sequential(
				nn.Conv2d(features, features, kernel_size=3+i*2, stride=stride, padding=1+i, groups=G),
				nn.BatchNorm2d(features),
				nn.ReLU(inplace=False)
			))
		self.gap = nn.AdaptiveAvgPool2d(1)
		self.fc = nn.Linear(features, d)
		self.fcs = nn.ModuleList([])
		for i in range(M):
			self.fcs.append(
				nn.Linear(d, features)
			)
		self.softmax = nn.Softmax(dim=1)

	def forward(self, x):
		for i, conv in enumerate(self.convs):
			fea = conv(x).unsqueeze_(dim=1)
			if i == 0:
				feas = fea
			else:
				feas = torch.cat([feas, fea], dim=1)
		fea_U = torch.sum(feas, dim=1)
		fea_s = self.gap(fea_U).squeeze_(-1).squeeze_(-1)
		fea_z = self.fc(fea_s)
		for i, fc in enumerate(self.fcs):
			vector = fc(fea_z).unsqueeze_(dim=1)
			if i == 0:
				attention_vectors =vector
			else:
				attention_vectors = torch.cat([attention_vectors, vector], dim=1)
		attention_vectors = self.softmax(attention_vectors)
		attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
		fea_v = (feas * attention_vectors).sum(dim=1)
		return fea_v


if __name__ == '__main__':
	torch.manual_seed(1)
	x = torch.rand(8, 64, 32, 32)
	conv = SKConv(64,3,8,2)
	out = conv(x)
	print(x.shape)


PyTorch中的 XXX_ 和 XXX 实现的功能都是相同的,唯一不同的是前者进行的是 in_place 操作。

参考

SKNet——SENet孪生兄弟篇

发布了93 篇原创文章 · 获赞 0 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/Dream_xd/article/details/104540203
今日推荐