YOLO V3 SPP ultralytics 第三节:关于yolo 中cfg的网络配置信息和读取cfg配置文件

目录

1. 介绍

2.  关于yolo的cfg网络配置文件

2.1 关于卷积层

2.2 关于池化层

2.3 关于捷径分支shortcut

2.4 关于route 层

2.5 关于上采样层

2.6 关于yolo层

3.  解析cfg 文件

4. 代码


1. 介绍

根据 第二节 的步骤,生成了属于自己的 my_yolov3.cfg 配置文件,本章将介绍yolo 配置文件的内容以及如何读取配置文件

 部分的yolo配置文件如下:

2.  关于yolo的cfg网络配置文件

因为搭建网络的时候,是根据配置文件cfg逐步实现的,因此理解 cfg网络配置文件也很重要

首先,关于net部分是用于训练的相关配置,这里用不到

TIPS : cfg 配置文件里面的内容不要做更改,因为固定的行号是确定的。删除了一个空格的话,索引的行号就对不上了

yolo v3 spp 网络如下:

2.1 关于卷积层

卷积层的开始是:[convolutional]

其中,batch_normalize和pad的1代表是否使用这两个参数,为1代表使用

2.2 关于池化层

池化层的开始是:[maxpool]

yolo v3 spp中,只有 SPP用maxpool操作,为了实现concatenate 操作,所以要保证shape相同,因此padding 的设定就是为了这个

yolo v3 spp 下采样用卷积 stride = 2实现

2.3 关于捷径分支shortcut

捷径分支shortcut的开始是:[shortcut]

-3 代表,前面-3的输出和自己相加

shortcut 是指两个不同信息的shape相同,再相加的操作

如图所示,第一个residual就是两个矩形框的输出相加

2.4 关于route 层

route层的开始是:[route]

spp 中,需要多个信息的融合,所以route层也很重要

route 的实现类似于指针

具体的如下:
当route 只有一个值的时候,可以理解为一个指针,返回对应的层结构

当route 有多个值的时候,将对应的输出拼接

concatenate 代表在 channel 维度堆起来

2.5 关于上采样层

上采样层的开始是:[upsample]

将图像的w和h扩大两倍

2.6 关于yolo层

yolo层的开始是:[yolo]

yolo 层并不在 spp 的网络图中,是3个尺度的后处理

前三组是小目标的anchor ,以此类推

3.  解析cfg 文件

代码是 parse_config.py

首先,先读取cfg的文件,去掉空格和注释

 lines 的部分内容为,这里每样保留了一个方便观看:

[

'[net]', 'batch=64', 'subdivisions=16', 'width=608', 'height=608', 'channels=3', 'momentum=0.9', 'decay=0.0005', 'angle=0', 'saturation = 1.5', 'exposure = 1.5', 'hue=.1', 'learning_rate=0.001', 'burn_in=1000', 'max_batches = 500200', 'policy=steps', 'steps=400000,450000', 'scales=.1,.1',

'[convolutional]', 'batch_normalize=1', 'filters=32', 'size=3', 'stride=1', 'pad=1', 'activation=leaky',

'[shortcut]', 'from=-3', 'activation=linear',

'[maxpool]', 'stride=1', 'size=5',

'[route]', 'layers=-2',

'[route]', 'layers=-1,-3,-5,-6',

'[upsample]', 'stride=2',

'[yolo]', 'mask = 6,7,8', 'anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326', 'classes=20', 'num=9', 'jitter=.3', 'ignore_thresh = .7', 'truth_thresh = 1', 'random=1'

]

type 存放网络的结构,后面跟着对应的配置,存放在一个字典中

需要注意的是,有的卷积后面不跟BN层,所以矩形框里面的内容不可忽略。因为大部分conv后面有BN,虽然设定为0,后面也会被替换成本来的值

后面是对key的相关操作,key是 = 之前的东西,val是 = 后面的东西,这里主要是将 = 后面的val 数值变成int或者float类型,因为默认读取的val是str类型

parse_config.py 代码中,还有一个是解析my_data.data 文件的,具体效果如下:

4. 代码

parse_config.py 的代码为:

# 解析网络中的配置文件

import os
import numpy as np


# 解析 my_yolov3.cfg 文件
def parse_model_cfg(path: str):
    if not path.endswith(".cfg") or not os.path.exists(path):    # 检查文件是否存在
        raise FileNotFoundError("the cfg file not exist...")

    # 读取文件信息
    with open(path, "r") as f:
        lines = f.read().split("\n")

    lines = [x for x in lines if x and not x.startswith("#")]   # 去除空行和注释行
    lines = [x.strip() for x in lines]  # 去除每行开头和结尾的空格符

    mdefs = []  # module definitions
    for line in lines:
        if line.startswith("["):    # 网络层都是[]形式
            mdefs.append({})
            mdefs[-1]["type"] = line[1:-1].strip()  # type 存放网络结构,[]里面的

            if mdefs[-1]["type"] == "convolutional":    # 如果是卷积模块,设置默认不使用BN,因为有的conv后面没有BN,0代表不启用BN
                mdefs[-1]["batch_normalize"] = 0
        else:                       # 网络层的参数
            key, val = line.split("=")          # 例如,learning_rate=0.001 用等号进行分割
            key = key.strip()
            val = val.strip()

            # yolo 层
            if key == "anchors":
                val = val.replace(" ", "")  # 将空格去除
                mdefs[-1][key] = np.array([float(x) for x in val.split(",")]).reshape((-1, 2))  # (9,2) anchor
            # 特殊结构
            elif key in ["from", "layers", "mask"]:
                mdefs[-1][key] = [int(x) for x in val.split(",")]
            # 常见的正常网络参数
            else:
                if val.isnumeric():  # return int or float 如果是数值的情况
                    mdefs[-1][key] = int(val) if (int(val) - float(val)) == 0 else float(val)
                else:
                    mdefs[-1][key] = val  # return string  是字符的情况

    # check all fields are supported
    supported = ['type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups',
                 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random',
                 'stride_x', 'stride_y', 'weights_type', 'weights_normalization', 'scale_x_y', 'beta_nms', 'nms_kind',
                 'iou_loss', 'iou_normalizer', 'cls_normalizer', 'iou_thresh', 'probability']

    # 遍历检查每个模型的配置
    for x in mdefs[1:]:  # 0对应 net配置
        # 遍历每个配置字典中的key值
        for k in x:
            if k not in supported:
                raise ValueError("Unsupported fields:{} in cfg".format(k))

    return mdefs


# 解析 my_data.data 文件,用于train的时候找到数据集
def parse_data_cfg(path):   
    if not os.path.exists(path) and os.path.exists('data' + os.sep + path):
        path = 'data' + os.sep + path

    with open(path, 'r') as f:
        lines = f.readlines()

    options = dict()
    for line in lines:
        line = line.strip()
        if line == '' or line.startswith('#'):
            continue
        key, val = line.split('=')
        options[key.strip()] = val.strip()

    return options


# info = parse_model_cfg('../cfg/my_yolov3.cfg')        # 测试解析 cfg文件
# info_data = parse_data_cfg('../data/my_data.data')
# print(info_data)        # {'classes': '20', 'train': 'data/my_train_data.txt', 'valid': 'data/my_val_data.txt', 'names': 'data/my_data_label.names'}

猜你喜欢

转载自blog.csdn.net/qq_44886601/article/details/130785081