卷积块CBM(conv+batchnorm+mish)

卷 积 块 C B M ( c o n v + b a t c h n o r m + m i s h ) 卷积块CBM(conv+batchnorm+mish) CBM(conv+batchnorm+mish)

Conv:提取特征
BN:1.防止梯度消失 2.防止过拟合 3.促进收敛
Mish:更优秀的激活函数,相比于其他,可以更有效的防止梯度消失

在这里插入图片描述

class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))# F.softplus(x) == torch.log(1+torch.exp(x))


class CBM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(CBM, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = Mish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

测试

rgb = torch.randn(1, 3, 32, 32) # (batchsize,channel,w,h)
#print(rgb)
print(rgb.shape)

在这里插入图片描述

test_downsample_conv = BasicConv(3, 1,3,stride=2)
x = test_downsample_conv(rgb)
#print(x)
print(x.shape)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_41375318/article/details/114456520
今日推荐