YOLOv5添加改进的解耦头

在看文献时,发现一篇对YOLOX的解耦头改进然后用在YOLOv5的论文,这里复现了下代码

再次声明下,并不是所有的改进用在你的数据集就会涨点,有时候也会适得其反,这个改进不行,那就换另一个,没必要认准一个死磕

一、改进解耦头结构

 为什么这么改,论文里也有提到



 二、代码部分

这里你需要去看我上一篇有关YOLOv5替换解耦头的博客,在此基础上进行的改进,代码如下

class DecoupledHead(nn.Module):
#########################   初始模块   ###########################################
    # def __init__(self, ch=256, nc=80,  anchors=()):
    #     super().__init__()
    #     self.nc = nc  # number of classes
    #     self.nl = len(anchors)  # number of detection layers
    #     self.na = len(anchors[0]) // 2  # number of anchors
    #     self.merge = Conv(ch, 256 , 1, 1)
    #     self.cls_convs1 = Conv(256 , 256 , 3, 1, 1)
    #     self.cls_convs2 = Conv(256 , 256 , 3, 1, 1)
    #     self.reg_convs1 = Conv(256 , 256 , 3, 1, 1)
    #     self.reg_convs2 = Conv(256 , 256 , 3, 1, 1)
    #     self.cls_preds = nn.Conv2d(256 , self.nc * self.na, 1)
    #     self.reg_preds = nn.Conv2d(256 , 4 * self.na, 1)
    #     self.obj_preds = nn.Conv2d(256 , 1 * self.na, 1)
    #
    # def forward(self, x):
    #     x = self.merge(x)
    #     x1 = self.cls_convs1(x)
    #     x1 = self.cls_convs2(x1)
    #     x1 = self.cls_preds(x1)
    #     x2 = self.reg_convs1(x)
    #     x2 = self.reg_convs2(x2)
    #     x21 = self.reg_preds(x2)
    #     x22 = self.obj_preds(x2)
    #     out = torch.cat([x21, x22, x1], 1)
    #     return out
    def __init__(self, ch=256, nc=80,  anchors=()):
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.cls_convs1 = Conv_shortcut(ch , ch , 3, 1)
        self.reg_convs1 = Conv_shortcut(ch , ch , 3, 1)
        self.cls_preds = nn.Conv2d(ch , self.nc * self.na, 1)  # 一个1x1的卷积,把通道数变成类别数,比如coco 80类(主要对目标框的类别,预测分数)
        self.reg_preds = nn.Conv2d(ch , 4 * self.na, 1)        # 一个1x1的卷积,把通道数变成4通道,因为位置是xywh
        self.obj_preds = nn.Conv2d(ch , 1 * self.na, 1)        # 一个1x1的卷积,把通道数变成1通道,通过一个值即可判断有无目标(置信度)

    def forward(self, x):
        x1 = self.cls_convs1(x)
        x1 = self.cls_preds(x1)
        x2 = self.reg_convs1(x)
        x21 = self.reg_preds(x2)
        x22 = self.obj_preds(x2)
        out = torch.cat([x21, x22, x1], 1)  # 把分类和回归结果按channel维度,即dim=1拼接
        return out
class Conv_shortcut(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        # # self.shortcut = nn.Sequential()
        # # if c1 != c2:
        # #     self.shortcut = nn.Sequential(
        # #         nn.Conv2d(c1, c2, kernel_size=1, stride=1, bias=False),
        # #         nn.BatchNorm2d(c2)
        # #     )

    def forward(self, x):
        y = self.bn(self.conv(x))
        # y = self.shortcut(x)
        # x1 += self.shortcut(x)
        y = self.act(y)
        return x + y

我运行没问题,有报错请在评论区评论(最近并没有研究yolo,报错可能爱莫能助)

猜你喜欢

转载自blog.csdn.net/Zeng999212/article/details/131655948