ResNeXt is a typical hybrid model, formed by combining the basic ideas of Inception (split-transform-merge) and ResNet (residual shortcuts). By studying it you can see how models you have already learned can be recombined; the point of mastering the essence of each model is to be able to fuse them into new ones.
Step 1: understand the meaning of the figure below.
The figure shows three ResNeXt block structures, (a), (b), and (c). The three are equivalent, but (c) is the easiest to express in code, so the code below follows (c). The essence of ResNeXt is grouped convolution (the groups parameter). I already explained grouping in detail in my earlier MobileNet post, so I won't repeat it here. In short, (a) splits the convolution into 32 parallel paths and sums their outputs; the groups argument of nn.Conv2d does this splitting for us automatically, which makes the code much simpler to write.
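To get a feel for what groups does, here is a minimal sketch (the 128-channel width is only an illustration; it matches the grouped conv inside the conv2 blocks later on). With groups=32, the 128 input channels are split into 32 independent paths of 4 channels each, so the weight count drops to 1/32 of an ordinary convolution:

import torch.nn as nn

# ordinary 3x3 conv: every output channel sees all 128 input channels
conv = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=1, bias=False)
# grouped 3x3 conv: 32 independent paths of 4 channels each
gconv = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=32, bias=False)
print(conv.weight.numel())   # 128 * 128 * 3 * 3 = 147456
print(gconv.weight.numel())  # 128 * 4 * 3 * 3   = 4608, i.e. 1/32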
Look carefully: apart from the channel counts, ResNeXt uses exactly the same parameters as ResNet (see my earlier detailed ResNet walkthrough). To briefly recap the overall flow: the image first passes through conv1, then pool1, then the first conv2 block. Note in the figure that the output size does not change at conv2, so we will set stride=1 there; the conv2 block is then repeated two more times. At the first conv3 block the output size does change, so stride=2 is used and the feature map shrinks to half its size; the following three repetitions use stride=1 and leave the feature map unchanged. The later stages work the same way.
self.conv2 = self._make_layer(64, 256, 1, num=layer[0])
self.conv3 = self._make_layer(256, 512, 2, num=layer[1])
self.conv4 = self._make_layer(512, 1024, 2, num=layer[2])
self.conv5 = self._make_layer(1024, 2048, 2, num=layer[3])
So conv2 uses stride=1, while conv3, conv4, and conv5 use stride=2 to halve the feature map; this matches the structure in the figure above. The channel counts in the figure are also regular: each stage's output is basically twice its input, with conv2 (64 in, 256 out) being the exception.
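As a quick sanity check on the stride choices, here is a sketch with a standalone 1x1 conv (the 64-channel, 56x56 input mimics what conv2 receives after the stem; the layers here are illustrative, not taken from the model):

import torch
import torch.nn as nn

x = torch.rand(1, 64, 56, 56)
keep = nn.Conv2d(64, 128, kernel_size=1, stride=1, bias=False)  # stride=1: size unchanged
half = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)  # stride=2: size halved
print(keep(x).shape)  # torch.Size([1, 128, 56, 56])
print(half(x).shape)  # torch.Size([1, 128, 28, 28])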
First, the code for the Block module:
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, is_shortcut=False):
        super(Block, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.is_shortcut = is_shortcut
        # 1x1 conv: reduce to out_channels // 2, the width of the grouped conv
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
        # 3x3 grouped conv with groups=32: the core of ResNeXt, structure (c)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=3, stride=1, padding=1,
                      groups=32, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
        # 1x1 conv: expand back to the full output width
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if is_shortcut:
            # projection shortcut: 1x1 conv matches channels and stride so the
            # residual add works; no bias needed because BatchNorm follows
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        x_shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_shortcut:
            x_shortcut = self.shortcut(x_shortcut)
        x = x + x_shortcut  # residual addition
        x = self.relu(x)
        return x
Only one point needs attention: the use of the projection shortcut self.shortcut. It is used exactly once per stage, in the first block; the repeated blocks that follow do not use it. Taking conv2 as an example: the first block has already consumed the shallow 64-channel input and projected it to 256 channels. The repeated blocks map (256, 256) and (256, 256), so input and output shapes are identical; the identity shortcut suffices, and a projection would have no effect even if used.
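To see concretely why the first block of a stage needs the projection, here is a minimal sketch (shapes chosen to match the first conv2 block; proj is a hypothetical standalone layer, not the model's): without the 1x1 projection, the residual addition fails because the channel counts differ.

import torch
import torch.nn as nn

x = torch.rand(1, 64, 56, 56)            # stage input: 64 channels
block_out = torch.rand(1, 256, 56, 56)   # first block output: 256 channels

# block_out + x would raise a runtime error (256 vs 64 channels),
# so the first block projects x with a 1x1 conv first:
proj = nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False)
y = block_out + proj(x)                  # shapes now match
print(y.shape)                           # torch.Size([1, 256, 56, 56])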
The complete code:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, is_shortcut=False):
        super(Block, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.is_shortcut = is_shortcut
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=3, stride=1, padding=1,
                      groups=32, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if is_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        x_shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_shortcut:
            x_shortcut = self.shortcut(x_shortcut)
        x = x + x_shortcut
        x = self.relu(x)
        return x

class Resnext(nn.Module):
    def __init__(self, num_classes, layer=[3, 4, 6, 3]):
        super(Resnext, self).__init__()
        # stem: 7x7 stride-2 conv + 3x3 stride-2 max pool, 224 -> 56
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # conv2 keeps the spatial size (stride=1); conv3-conv5 halve it (stride=2)
        self.conv2 = self._make_layer(64, 256, 1, num=layer[0])
        self.conv3 = self._make_layer(256, 512, 2, num=layer[1])
        self.conv4 = self._make_layer(512, 1024, 2, num=layer[2])
        self.conv5 = self._make_layer(1024, 2048, 2, num=layer[3])
        self.global_average_pool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _make_layer(self, in_channels, out_channels, stride, num):
        layers = []
        # only the first block of a stage projects the shortcut (and may downsample)
        block_1 = Block(in_channels, out_channels, stride=stride, is_shortcut=True)
        layers.append(block_1)
        for i in range(1, num):
            layers.append(Block(out_channels, out_channels, stride=1, is_shortcut=False))
        return nn.Sequential(*layers)

net = Resnext(10)
x = torch.rand((10, 3, 224, 224))
for name, layer in net.named_children():
    if name != "fc":
        x = layer(x)
        print(name, 'output shape:', x.shape)
    else:
        x = x.view(x.size(0), -1)
        x = layer(x)
        print(name, 'output shape:', x.shape)
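For a 224x224 input the printed shapes are fully determined by the strides above, so the loop should print something like:

conv1 output shape: torch.Size([10, 64, 56, 56])
conv2 output shape: torch.Size([10, 256, 56, 56])
conv3 output shape: torch.Size([10, 512, 28, 28])
conv4 output shape: torch.Size([10, 1024, 14, 14])
conv5 output shape: torch.Size([10, 2048, 7, 7])
global_average_pool output shape: torch.Size([10, 2048, 1, 1])
fc output shape: torch.Size([10, 10])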