1. Overview
DenseNet borrows the shortcut connection introduced in ResNet (a direct link from input to output) and takes it further with dense connectivity. A traditional L-layer network stacked sequentially has L connections between layers; under dense connectivity the counts grow as 1, 3, 6, 10, …, so an L-layer network has L(L+1)/2 connections. In other words, the outputs of all preceding layers are passed on to every subsequent layer. Each layer therefore only needs to produce a relatively small number of new feature maps, while the fusion of shallow and deep features is greatly enhanced.
Compared with other networks, DenseNet has the following advantages:
1) Like ResNet, the network can be made very deep, improving accuracy without worrying about vanishing gradients or degradation
2) It strengthens feature propagation
3) It increases feature reuse
4) It substantially reduces the number of parameters
Its main contribution, the DenseBlock, is shown in the figure below. Besides being passed directly to the next layer, each layer's output is also connected via skip connections to every later layer, greatly improving feature propagation and feature reuse.
As the figure shows, the input of each densely connected layer is the concatenation of all preceding layers' outputs, and the feature-extraction unit is BN + ReLU + Conv. Feature maps are not spatially compressed inside a DenseBlock, so pooling is deferred to the Transition Layer.
$$x_l = h_l([x_0, x_1, x_2, \dots, x_{l-1}])$$
Every layer inside a DenseBlock outputs the same number of channels k (the growth rate), so the input of the l-th layer of a block has
$$k_l = k_0 + k \times (l - 1)$$
channels, where k0 is the number of channels of the block's input.
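For example, in the first block of DenseNet-121 (k0 = 64, k = 32), the sixth layer receives 64 + 32 × 5 = 224 input channels.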
2. Network Modules
The main structure of DenseNet is DenseBlock + Transition Layer; stacking different numbers of them yields DenseNets of different depths. The figure below shows a simple DenseNet with three DenseBlocks.
The Dense Layer inside a DenseBlock comes in two forms. The first, shown below, contains only BN + ReLU + Conv and always outputs a feature map with a fixed k channels.
As the network deepens, the inputs of later layers keep growing. To cut the computation effectively, the Dense Layer was improved with a bottleneck, giving the module shown below: a 1×1 convolution is added in front of the original Dense Layer to reduce the channel count before the 3×3 convolution.
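To see the saving with some illustrative numbers: for an input of 1024 channels and k = 32, a direct 3×3 convolution costs about 9 × 1024 × 32 ≈ 295K multiply-adds per spatial position, while the bottleneck costs 1024 × 128 for the 1×1 conv (with 4k = 128 output channels) plus 9 × 128 × 32 for the 3×3 conv, about 168K in total; the gap widens as the input grows.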
The concrete implementation is as follows:
# *****************normal dense layer*************************
class DenseLayer(nn.Module):
    def __init__(self, num_input_channels, growth_rate, expansion):
        super(DenseLayer, self).__init__()
        self.bn = nn.BatchNorm2d(num_features=num_input_channels)
        # the conv must consume exactly the BN output, so in_channels is
        # num_input_channels; expansion is unused here and only kept so the
        # constructor signature matches DenseLayer_BC
        self.conv = nn.Conv2d(in_channels=num_input_channels, out_channels=growth_rate,
                              kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        return x
# *****************bottleneck dense layer*************************
class DenseLayer_BC(nn.Module):
    def __init__(self, num_input_channels, growth_rate, expansion):
        # num_input_channels: channels of the input feature map; growth_rate: k (e.g. 32);
        # expansion: output channels of the 1x1 conv as a multiple of k (4 in the paper)
        super(DenseLayer_BC, self).__init__()
        self.bn1 = nn.BatchNorm2d(num_features=num_input_channels)
        self.conv1_1 = nn.Conv2d(in_channels=num_input_channels, out_channels=growth_rate * expansion,
                                 kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=growth_rate * expansion)
        self.conv3_3 = nn.Conv2d(in_channels=growth_rate * expansion, out_channels=growth_rate,
                                 kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # BN + ReLU + 1x1 conv (bottleneck), then BN + ReLU + 3x3 conv
        x = self.conv1_1(self.relu(self.bn1(x)))
        x = self.conv3_3(self.relu(self.bn2(x)))
        return x
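A quick sanity check (a minimal sketch, assuming the imports from the full listing at the end): both variants map an input of any channel count to a k-channel output of the same spatial size.

layer = DenseLayer(num_input_channels=64, growth_rate=32, expansion=1)
layer_bc = DenseLayer_BC(num_input_channels=64, growth_rate=32, expansion=4)
x = torch.rand(1, 64, 56, 56)
print(layer(x).shape)     # torch.Size([1, 32, 56, 56])
print(layer_bc(x).shape)  # torch.Size([1, 32, 56, 56])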
Combining the Dense Layers above, a DenseBlock can be implemented simply as follows:
# *****************denseblock*************************
class DenseBlock(nn.Module):
    def __init__(self, in_channel, block, num_layers, growth_rate, expansion, drop_rate):
        super(DenseBlock, self).__init__()
        self.in_channel = in_channel
        self.block = block
        self.num_layers = num_layers
        self.growth_rate = growth_rate
        self.expansion = expansion
        self.drop_rate = drop_rate
        self.layers = self._make_layers()

    def _make_layers(self):
        layers = []
        for i in range(self.num_layers):
            # layer i receives the block input plus i earlier outputs of k channels each
            layers.append(self.block(self.in_channel + self.growth_rate * i, self.growth_rate, self.expansion))
        # use nn.ModuleList (not a plain list) so the sub-layers are registered as submodules
        return nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            feature = layer(out)
            if self.drop_rate > 0:
                feature = F.dropout(feature, p=self.drop_rate, training=self.training)
            # concatenate the new k-channel feature map onto everything produced so far
            out = torch.cat([out, feature], dim=1)
        return out
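A quick check of the channel bookkeeping (illustrative values): a block with 64 input channels, 6 layers, and k = 32 should output 64 + 6 × 32 = 256 channels.

block = DenseBlock(in_channel=64, block=DenseLayer_BC, num_layers=6,
                   growth_rate=32, expansion=4, drop_rate=0)
x = torch.rand(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 256, 56, 56])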
Next is the Transition Layer, which connects two adjacent DenseBlocks and compresses the feature maps, both in channel count (1×1 conv) and spatially (average pooling). It consists of BN + ReLU + Conv + AvgPool, implemented as follows:
# *****************transition layer*************************
class Transition(nn.Module):
    def __init__(self, in_channel, compress_rate=0.5):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(num_features=in_channel)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv compresses the channel count by compress_rate
        self.conv1_1 = nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel * compress_rate),
                                 kernel_size=1, stride=1, bias=False)
        # 2x2 average pooling halves the spatial resolution
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1_1(self.relu(self.bn1(x)))
        x = self.avg_pool(x)
        return x
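With the default compress_rate of 0.5, a transition halves both the channel count and the spatial size (a minimal check, reusing the 256-channel output from the sketch above):

trans = Transition(in_channel=256)
x = torch.rand(1, 256, 56, 56)
print(trans(x).shape)  # torch.Size([1, 128, 28, 28])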
3. Model Architecture
The overall DenseNet structure is shown below. The input is 224×224; a large 7×7 convolution first downsamples it, followed by a 3×3 max-pooling operation. Then come the stacked DenseBlock + Transition Layer stages, and finally a 7×7 global average pooling, a fully connected layer, and softmax produce the classification result.
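For reference, the standard DenseNet-121 stages from the paper (output sizes for a 224×224 input):
7×7 conv, stride 2 → 112×112
3×3 max pool, stride 2 → 56×56
DenseBlock (6 layers) + Transition → 28×28
DenseBlock (12 layers) + Transition → 14×14
DenseBlock (24 layers) + Transition → 7×7
DenseBlock (16 layers) → 7×7
7×7 global average pool + 1000-way FC + softmax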
Finally, here is the full DenseNet implementation.
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchsummary import summary
import re
import torch.utils.model_zoo as model_zoo
# *****************pretrained model*************************
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
# *****************normal dense layer*************************
class DenseLayer(nn.Module):
    def __init__(self, num_input_channels, growth_rate, expansion):
        super(DenseLayer, self).__init__()
        self.bn = nn.BatchNorm2d(num_features=num_input_channels)
        # the conv must consume exactly the BN output, so in_channels is
        # num_input_channels; expansion is unused here and only kept so the
        # constructor signature matches DenseLayer_BC
        self.conv = nn.Conv2d(in_channels=num_input_channels, out_channels=growth_rate,
                              kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        return x
# *****************bottleneck dense layer*************************
class DenseLayer_BC(nn.Module):
    def __init__(self, num_input_channels, growth_rate, expansion):
        # num_input_channels: channels of the input feature map; growth_rate: k (e.g. 32);
        # expansion: output channels of the 1x1 conv as a multiple of k (4 in the paper)
        super(DenseLayer_BC, self).__init__()
        self.bn1 = nn.BatchNorm2d(num_features=num_input_channels)
        self.conv1_1 = nn.Conv2d(in_channels=num_input_channels, out_channels=growth_rate * expansion,
                                 kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=growth_rate * expansion)
        self.conv3_3 = nn.Conv2d(in_channels=growth_rate * expansion, out_channels=growth_rate,
                                 kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # BN + ReLU + 1x1 conv (bottleneck), then BN + ReLU + 3x3 conv
        x = self.conv1_1(self.relu(self.bn1(x)))
        x = self.conv3_3(self.relu(self.bn2(x)))
        return x
# *****************denseblock*************************
class DenseBlock(nn.Module):
    def __init__(self, in_channel, block, num_layers, growth_rate, expansion, drop_rate):
        super(DenseBlock, self).__init__()
        self.in_channel = in_channel
        self.block = block
        self.num_layers = num_layers
        self.growth_rate = growth_rate
        self.expansion = expansion
        self.drop_rate = drop_rate
        self.layers = self._make_layers()

    def _make_layers(self):
        layers = []
        for i in range(self.num_layers):
            # layer i receives the block input plus i earlier outputs of k channels each
            layers.append(self.block(self.in_channel + self.growth_rate * i, self.growth_rate, self.expansion))
        # use nn.ModuleList (not a plain list) so the sub-layers are registered as submodules
        return nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            feature = layer(out)
            if self.drop_rate > 0:
                feature = F.dropout(feature, p=self.drop_rate, training=self.training)
            # concatenate the new k-channel feature map onto everything produced so far
            out = torch.cat([out, feature], dim=1)
        return out
# *****************transition layer*************************
class Transition(nn.Module):
    def __init__(self, in_channel, compress_rate=0.5):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(num_features=in_channel)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv compresses the channel count by compress_rate
        self.conv1_1 = nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel * compress_rate),
                                 kernel_size=1, stride=1, bias=False)
        # 2x2 average pooling halves the spatial resolution
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1_1(self.relu(self.bn1(x)))
        x = self.avg_pool(x)
        return x
# *****************DenseNet*************************
class DenseNet(nn.Module):
    def __init__(self, block, growth_rate=32, num_block=(6, 12, 24, 16), expansion=4, compression_rate=0.5,
                 drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate
        self.num_block = num_block
        self.num_init_features = 2 * self.growth_rate
        self.expansion = expansion
        self.compression_rate = compression_rate
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.num_init_features, kernel_size=7, stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(self.num_init_features)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.denseblocks = self._make_denseblock(block)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # _make_denseblock updates self.num_init_features as it goes, so at this
        # point it holds the channel count of the final block's output
        self.fc = nn.Linear(self.num_init_features, num_classes)
        self._initialize_weights()

    def _make_denseblock(self, block):
        layers = []
        for i, num in enumerate(self.num_block):
            layers.append(
                DenseBlock(self.num_init_features, block, num, self.growth_rate, self.expansion, self.drop_rate))
            self.num_init_features += num * self.growth_rate
            # every block except the last is followed by a transition layer
            if i != len(self.num_block) - 1:
                layers.append(Transition(self.num_init_features, compress_rate=self.compression_rate))
                self.num_init_features = int(self.num_init_features * self.compression_rate)
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Weight initialization."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming initialization for conv layers
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                # biases initialized to 0
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # BN weights initialized to 1, biases to 0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # fully connected layer initialization
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))  # [N, 64, 112, 112]
        x = self.maxpool(x)                     # [N, 64, 56, 56]
        x = self.denseblocks(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def DenseNet_121(block, expansion, pretrained=False):
    # with DenseLayer_BC, expansion must be 4; with DenseLayer, set expansion to 1
    model = DenseNet(block, num_block=(6, 12, 24, 16), expansion=expansion)
    if pretrained:
        # '.'s are no longer allowed in module names, but the previous _DenseLayer
        # had keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2',
        # which still appear in the checkpoints in model_urls. This pattern is
        # used to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # NOTE: the checkpoint follows torchvision's parameter naming; the keys must
        # still be mapped onto this implementation's names for load_state_dict to succeed
        model.load_state_dict(state_dict)
    return model
def DenseNet_169(block, expansion, pretrained=False):
    # with DenseLayer_BC, expansion must be 4; with DenseLayer, set expansion to 1
    model = DenseNet(block, num_block=(6, 12, 32, 32), expansion=expansion)
    if pretrained:
        # same key-renaming scheme as in DenseNet_121
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet169'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model
def DenseNet_201(block, expansion, pretrained=False):
    # with DenseLayer_BC, expansion must be 4; with DenseLayer, set expansion to 1
    model = DenseNet(block, num_block=(6, 12, 48, 32), expansion=expansion)
    if pretrained:
        # same key-renaming scheme as in DenseNet_121
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet201'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model
def DenseNet_264(block, expansion, pretrained=False):
    # with DenseLayer_BC, expansion must be 4; with DenseLayer, set expansion to 1
    model = DenseNet(block, num_block=(6, 12, 64, 48), expansion=expansion)
    if pretrained:
        # no official pretrained checkpoint is published for DenseNet-264
        raise NotImplementedError('pretrained weights are not available for DenseNet-264')
    return model
if __name__ == '__main__':
    # net = DenseNet(DenseLayer_BC)
    net = DenseNet_169(DenseLayer, expansion=1, pretrained=False)
    x = torch.rand((1, 3, 224, 224))
    out = net(x)
    print(out.shape)
    summary(net, (3, 224, 224), device='cpu')
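If everything is wired correctly, this prints torch.Size([1, 1000]) (the default 1000-class output) followed by torchsummary's layer-by-layer report.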