YOLOX: PyTorch implementation of the network structure
Introduction to the overall structure of the network
YOLOX's network consists of three main parts: CSPDarkNet, FPN, and YOLOXHead.
- CSPDarkNet is YOLOX's backbone feature-extraction network; it outputs three effective feature layers.
- FPN is YOLOX's enhanced feature-extraction network. It fuses the three effective feature layers output by CSPDarkNet, combining feature information from different scales.
- YOLOXHead is the classifier and regressor of YOLOX. It uses the three feature maps output by FPN to determine whether each feature point corresponds to an object. Earlier YOLO heads performed classification and regression in a single convolution; YOLOXHead instead implements classification and regression in separate branches and merges their results at the end.
PyTorch implementation of the backbone network CSPDarkNet
The CSPDarkNet backbone network is mainly composed of Focus, CSPLayer, BaseConv, DWConv, Bottleneck, and SPPBottleneck. Its network structure is shown in the figure below:
Basic modules
BaseConv
BaseConv consists of a convolutional layer, a BN layer, and an activation layer.
class BaseConv(nn.Module):
    """Convolution followed by batch normalization and an activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        ksize: Square kernel size.
        stride: Convolution stride.
        groups: Number of convolution groups (``groups == in_channels``
            gives a depthwise convolution).
        bias: Whether the convolution carries a bias term.
        act: Activation name resolved through ``get_activation``.
    """

    def __init__(self, in_channels, out_channels, ksize, stride,
                 groups=1, bias=False, act="silu"):
        super().__init__()
        # (k - 1) // 2 padding keeps H and W unchanged for odd kernels at stride 1.
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad,
                              groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        """Apply conv -> BN -> activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
def get_activation(name="silu", inplace=True):
    """Build an activation module from its short name.

    Args:
        name: One of "silu", "relu" or "lrelu" (LeakyReLU with negative
            slope 0.1).
        inplace: Passed to the module's ``inplace`` flag for every supported
            activation.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        AttributeError: If ``name`` is not a supported activation.
    """
    if name == "silu":
        # Use torch's built-in SiLU (x * sigmoid(x)); the bare ``SiLU`` class
        # referenced here previously is not defined in this listing, and the
        # ``inplace`` flag was silently ignored for this branch.
        module = nn.SiLU(inplace=inplace)
    elif name == "relu":
        module = nn.ReLU(inplace=inplace)
    elif name == "lrelu":
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError("Unsupported act type: {}".format(name))
    return module
DWConv
Depthwise separable convolution consists of a channel-by-channel (depthwise) convolution followed by a point-by-point (1x1) convolution. For a detailed introduction, see: Depthwise Separable Convolution
class DWConv(nn.Module):
    """Depthwise separable convolution: a depthwise conv, then a 1x1 pointwise conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (set by the pointwise stage).
        ksize: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage; the pointwise stage always uses 1.
        act: Activation name used by both stages.
    """

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        # groups == in_channels: every input channel is filtered independently.
        self.dconv = BaseConv(in_channels, in_channels, ksize, stride,
                              groups=in_channels, act=act)
        # 1x1 convolution that mixes channels and sets the output width.
        self.pconv = BaseConv(in_channels, out_channels, 1, 1,
                              groups=1, act=act)

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        spatial = self.dconv(x)
        return self.pconv(spatial)
Focus network implementation
The Focus network samples every other pixel of the input image, starting at four offsets — (row 1, column 1), (row 1, column 2), (row 2, column 1) and (row 2, column 2) — to form four new images. Each of these has half the width and height of the original. Stacking them along the channel dimension makes the number of channels four times the original, i.e. 12 for a 3-channel input image. (The code below concatenates them in the order top-left, bottom-left, top-right, bottom-right.)
Python slicing makes it easy to extract the four independent feature layers, which are then stacked with torch.cat.
class Focus(nn.Module):
    """Space-to-depth stem: fold each 2x2 pixel neighbourhood into channels, then convolve.

    The four pixel-parity sub-images of ``x`` (last two dims sliced with step 2)
    are stacked on dim 1, so ``self.conv`` receives ``4 * in_channels`` channels
    at half the spatial size. The last two dims are expected to be even;
    otherwise the slices differ in size and ``torch.cat`` fails.
    (Presumably NCHW input — TODO confirm with callers.)

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the following BaseConv.
        ksize: Kernel size of that BaseConv.
        stride: Stride of that BaseConv.
        act: Activation name for that BaseConv.
    """

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        self.conv = BaseConv(4 * in_channels, out_channels, ksize, stride, act=act)

    def forward(self, x):
        """Stack the four pixel-parity sub-images along dim 1 and convolve."""
        # (row offset, col offset): top-left, bottom-left, top-right, bottom-right.
        # The conv weights are tied to this channel layout, so keep the order fixed.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        parts = [x[..., r::2, c::2] for r, c in offsets]
        return self.conv(torch.cat(parts, dim=1))
CSPLayer network implementation
The CSPLayer network is built by stacking Bottleneck residual blocks. The main branch of a Bottleneck consists of a 1x1 convolution followed by a 3x3 convolution. The residual (shortcut) branch is left unprocessed, and the two are added together at the end (when shortcut is enabled and the input and output channel counts match).
class Bottleneck(nn.Module):
    """Residual block: 1x1 conv then 3x3 conv, with an optional identity shortcut.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        shortcut: Request the residual addition; it is applied only when
            ``in_channels == out_channels``.
        expansion: Hidden width as a fraction of ``out_channels``.
        depthwise: Use DWConv instead of BaseConv for the 3x3 stage.
        act: Activation name.
    """

    def __init__(self, in_channels, out_channels, shortcut=True,
                 expansion=0.5, depthwise=False, act="silu"):
        super().__init__()
        mid_channels = int(out_channels * expansion)
        ConvBlock = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, mid_channels, 1, stride=1, act=act)
        self.conv2 = ConvBlock(mid_channels, out_channels, 3, stride=1, act=act)
        # An identity add is only shape-valid when the channel counts match.
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        """Apply both convolutions and add the input back when enabled."""
        out = self.conv2(self.conv1(x))
        return out + x if self.use_add else out
The main branch of CSPLayer stacks the residual blocks described above. The other branch acts like a residual edge: after a small amount of processing (a 1x1 convolution), it is concatenated with the main branch at the end. CSP can therefore be thought of as having one large residual edge.
class CSPLayer(nn.Module):
    """Cross Stage Partial layer: a Bottleneck branch and a 1x1 side branch, concatenated.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n: Number of stacked Bottlenecks on the main branch.
        shortcut: Forwarded to every Bottleneck.
        expansion: Width of each branch as a fraction of ``out_channels``.
        depthwise: Forwarded to every Bottleneck.
        act: Activation name.
    """

    def __init__(self, in_channels, out_channels, n=1, shortcut=True,
                 expansion=0.5, depthwise=False, act="silu"):
        super().__init__()
        branch_width = int(out_channels * expansion)
        self.conv1 = BaseConv(in_channels, branch_width, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, branch_width, 1, stride=1, act=act)
        # Fuses the two concatenated branches back to ``out_channels``.
        self.conv3 = BaseConv(2 * branch_width, out_channels, 1, stride=1, act=act)
        blocks = [
            Bottleneck(branch_width, branch_width, shortcut=shortcut,
                       expansion=1, depthwise=depthwise, act=act)
            for _ in range(n)
        ]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        """Run both branches on ``x``, concatenate on dim 1 and fuse."""
        main_branch, side_branch = self.conv1(x), self.conv2(x)
        main_branch = self.m(main_branch)
        return self.conv3(torch.cat((main_branch, side_branch), dim=1))
SPPBottleneck network implementation
SPPBottleneck extracts features by applying max pooling with several different kernel sizes, which enlarges the receptive field of the network.
class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling: concatenate the input with several max-pooled copies.

    Each pool uses stride 1 and padding ``k // 2``. For odd ``k`` this keeps
    the spatial size, so all copies can be concatenated on dim 1.

    Args:
        in_channels: Number of input channels (halved by the first 1x1 conv).
        out_channels: Number of output channels.
        ksizes: Max-pool kernel sizes.
        act: Activation name.
    """

    def __init__(self, in_channels, out_channels, ksizes=(5, 9, 13), act="silu"):
        super().__init__()
        reduced = in_channels // 2
        self.conv1 = BaseConv(in_channels, reduced, 1, stride=1, act=act)
        pools = []
        for k in ksizes:
            pools.append(nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2))
        self.m = nn.ModuleList(pools)
        # The un-pooled tensor plus one tensor per kernel size.
        self.conv2 = BaseConv(reduced * (len(ksizes) + 1), out_channels, 1,
                              stride=1, act=act)

    def forward(self, x):
        """Reduce channels, pool at every kernel size, concatenate and project."""
        x = self.conv1(x)
        pyramid = [x]
        pyramid.extend(pool(x) for pool in self.m)
        return self.conv2(torch.cat(pyramid, dim=1))
CSPDarkNet network implementation
Following the figure below, the basic modules implemented earlier are combined to build YOLOX's backbone network, CSPDarkNet.
class CSPDarknet(nn.Module):
    """YOLOX backbone: a Focus stem followed by four stride-2 stages.

    Args:
        dep_mul: Depth multiplier scaling the number of Bottlenecks per CSPLayer.
        wid_mul: Width multiplier scaling the channel counts (base 64).
        out_features: Names of the stage outputs returned by ``forward``;
            drawn from "stem", "dark2", "dark3", "dark4", "dark5".
        depthwise: Use DWConv instead of BaseConv for the strided convolutions
            and inside the Bottlenecks.
        act: Activation name used throughout.
    """

    def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"),
                 depthwise=False, act="silu"):
        super().__init__()
        self.out_features = out_features
        DownConv = DWConv if depthwise else BaseConv
        width = int(wid_mul * 64)
        depth = max(round(dep_mul * 3), 1)

        def make_stage(c_in, c_out, n, with_spp=False, shortcut=True):
            # Stride-2 3x3 conv, optional SPP, then a CSPLayer. Build order is
            # kept identical so parameter registration and init order match.
            layers = [DownConv(c_in, c_out, 3, 2, act=act)]
            if with_spp:
                layers.append(SPPBottleneck(c_out, c_out, act=act))
            layers.append(CSPLayer(c_out, c_out, n=n, shortcut=shortcut,
                                   depthwise=depthwise, act=act))
            return nn.Sequential(*layers)

        self.stem = Focus(3, width, ksize=3, act=act)
        self.dark2 = make_stage(width, width * 2, depth)
        self.dark3 = make_stage(width * 2, width * 4, depth * 3)
        self.dark4 = make_stage(width * 4, width * 8, depth * 3)
        self.dark5 = make_stage(width * 8, width * 16, depth,
                                with_spp=True, shortcut=False)

    def forward(self, x):
        """Run every stage in order and return the requested outputs by name."""
        features = {}
        for stage_name in ("stem", "dark2", "dark3", "dark4", "dark5"):
            x = getattr(self, stage_name)(x)
            features[stage_name] = x
        return {name: feat for name, feat in features.items()
                if name in self.out_features}