yolov5的parse_model函数解析

# 基于YOLOV5-6.0版本
# 解析模型的配置文件(yml)以字典的形式读取yml文件,分析可得yml文件一共6部分
# 分别对应:nc depth_multiple width_multiple anchors
# backbone head
def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info(f"\n{
      
      '':>3}{
      
      'from':>18}{
      
      'n':>3}{
      
      'params':>10}  {
      
      'module':<40}{
      
      'arguments':<30}")
    # 根据字典的键值对进行取值
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    # 进行判断anchors是不是一个列表 6//2=3
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args, 两个字典相加合并为一个字典
        # 判断模型配置文件的module是否是字符串,eval()函数用来执行一个字符串表达式,并返回表达式的值。
        # 写这一段代码干嘛????
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass
        # 控制模型深度的代码,n=1不大于1,则n=1,否则如果n=6>1,则6*gd=6*0.5取整等于3
        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        # 重新确定一下网络中每一层的输入、输出通道大小,对网络中每一层参数重新判别更改,最主要是确定c2大小
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
            # f=-1
            c1, c2 = ch[f], args[0]
            # 这里的no不是NO的意思,而是一个变量no = na * (nc + 5)
            # 只要我的输出通道不是最后一层,我都要给输出通道数变为靠近8的整数倍
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)  # 控制宽度的代码(卷积核个数)
            # *arg表示接收任意数量的参数,调用时会将实际参数打入一个元组传入实参
            # 原始的arg参数进行了解析,并且重新构造了args列表,因为原始传入的第一个参数为输出通道数,并非最终的通道数,并且args列表新增c1
            args = [c1, c2, *args[1:]]
            # 判断是否重复几次,可见提供的也就以下几个模块
            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                # 参数列表在第2个位置新增模块循环的次数数据
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            # c1,c2为通道数大小,ch为一个列表,存储每一层通道数的数值
            c2 = sum(ch[x] for x in f)
        elif m is Detect:
            # 获取第17.20.23层的通道数,[ch[x] for x in f]这表示里面是个列表,元素值就是ch[17],ch[20],ch[23],
            # 即args.append([128,256,512]),新增了一个元素,该元素是3个值的列表
            args.append([ch[x] for x in f])
            # 正常情况不会进if语句,args[1] = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
            # isinstance(args[1], int) = False
            # 如果args[1] = 10,则isinstance(args[1], int) = True
            if isinstance(args[1], int):  #  anchors
                # 重新构造出预设anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]
        # 对当前模块序列化之后的层传递给m_
        # 如果n=1,说明该模块就一个,不重复,直接 m_ = m(*args)
        # 如果n>1,说明模块重复了几次,m_ = nn.Sequential(*(m(*args) for _ in range(n)))
        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        # --------------------------输出到terminal的模型信息,进行记录输出----------------------------------
        t = str(m)[8:-2].replace('__main__.', '')  # module type,比如models.common.C3TR
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{
      
      i:>3}{
      
      str(f):>18}{
      
      n_:>3}{
      
      np:10.0f}  {
      
      t:<40}{
      
      str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        # ----------------------------------------!!!------------------------------------------------
        # layers为整个model的层,将模块的层添加到整个layers层
        layers.append(m_)
        # ch在此定义,即每一层输出通道数的列表,一共24个元素,对应yolov5s-24层每层通道数
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

猜你喜欢

转载自blog.csdn.net/a18838956649/article/details/121449089
今日推荐