yolov5-2.0 detect.py详细解析及注释

import argparse
import os
import platform
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()    # 将路径或者路径段解析为绝对路径
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))    # Path.cwd()就是当前工作目录

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode

@torch.no_grad()  # 数据不需要计算梯度,也不会进行反向传播
def run (weights=ROOT / 'yolov5s.pt',    # 加载训练权重
         # 测试数据,可以是图片/视频路径,也可以是'0'(电脑自带摄像头),也可以使用rtsp等视频流
         source=ROOT / 'data/imsges',
         data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
         # 网络的输入图片大小(高,宽)
         imgsz=(640,640),
         # 置信度阈值
         conf_thres=0.25,
         # 做nms的iou阈值,默认为0.45
         iou_thres=0.45,
         max_det=1000,  # 保留的最大检测框数量,每张图片中检测目标的个数最多为1000类
         device='',  # cuda驱动,0,1,2,3或者cpu
         view_img=False,  # 是否展示预测之后的图片
         save_txt=False,  # 是否将预测框坐标以txt形式保存,默认为False不保存
         save_conf=False,  # 是否将置信度保存到txt中,默认为Falsee不保存
         save_crop=False,  # 是否保存剪裁预测框图片,默认不保存
         nosave=False,  # 是否被推理的图片或者视频在run/detect/exp*中,默认不保存
         classes=None,  # 是否只保留某一类别,默认为保留全部,需要保留某一类别是将None改为类别号
         agnostic_nms=False,  # 进行NMS时是否去除不同类别之间的框,默认为不去除
         augment=False,  # 推理时是否进行多尺度、翻转(TTA)推理
         visualize=False,  # 是否可视化网络层输出特征
         update=False,  # 如果为Ture,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
         project=ROOT / 'runs/detect',   # 推理结果保存路径
         name='exp',   # 保存结果输出的文件名
         exist_ok=False,   # 是否重新创建日志文件,为False时重新创建文件
         line_thickness=3,   # 画框线条的粗细
         hide_labels=False,   # 可视化时是否隐藏标签
         hide_conf=False,    # 可视化时是否隐藏置信度
         half=False,    # 是否使用半精度推理(F16),使用半精度可以提高检测速度
         dnn=False,    # 是否用OpenCV DNN进行预测
         vid_stride=1,    # 视频输入的帧率
):
    source = str(source)    # 将输入路径变为字符串
    save_img = not nosave and not source.endswith('.txt')    # 是否保存图片和txt文件
    # 判断文件是否是视频流  *.isnumeric路径是否有数字组成,返回Ture或者False,
    #                   *.lower()转化为小写
    #                   *.endswith()判断是否以制定的字符或者字符串结尾
    #                   *.startswith()判断当前字符串是否以另外一个给定的字符串开头的,根据判断结果返回true或false
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')
    if is_url and is_file:
        source = check_file(source)  # download


    # 预测路径是否存在,不存在就新建,按照实验文件以递增的形式新建
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)

    device = select_device(device)    # 获取设备
    # 调用common.py中的DetectMultiBackend,进行推理。加载模型权重,设备,数据以及是否使用半精度、和dnn
    models = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = models.stride, models.names, models.pt
    # 确保输入图片的尺寸imgsz能整除32(strde=32)如果不能则调整为能被整除,并返回。
    imgsz = check_img_size(imgsz, s=stride)

    bs = 1
    if webcam:
        view_img = check_imshow(warn=True)
        # 调用dataloaders.py中的LoadStreams类( 流加载器 )加载视频流,帧率为vid_stride(初始化时自己设定)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        # 计算dataset的长度
        bs = len(dataset)
    elif screenshot:
        # 调用dataloders.py中LoadScreenshots类(截图数据加载器)
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        # 直接从source文件夹下读取图片
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    # 保存路径
    vid_path, vid_writer = [None] * bs, [None] * bs

    # 推理前测试  进行一次前向推理,测试模型是否可以跑通
    models.warmup(imgsz=(1 if pt or models.triton else bs, 3, *imgsz))
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())

    # 正式推理 从上面的LoadImages()可以看到每次只输入单张图片
    # 处理每一张图片或者视频的格式
    for path, im, im0s, vid_cap, s in dataset:
        # path 是图片或视频的路径
        # im 是进行resize + pad之后的图片
        # im0s 是原尺寸图片
        # vid_cap 当读取的图片为None时,读取视频为视频源
        with dt[0]:
            # torch.from_numpy()的作用是将生成的数组转换为张量。
            im = torch.from_numpy(im).to(models.device)
            # 图片设置为FP16/FP32
            im = im.half() if models.fp16 else im.float()
            im /= 255  # 归一化
            # 如果没有batch_size的话则在最前面添加一个轴
            if len(im.shape) == 3:
                # 增加一个维度
                im = im[None]

        with dt[1]:
            # 可视化文件路径 调用utils/general.py文件中的increment_path类增加文件或目录路径
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            # augment推理时是否进行多尺度、翻转(TTA)推理
            pred = models(im, augment=augment, visualize=visualize)


        # 非极大值抑制 NMS
        '''非最大抑制应用于“减薄”边缘。应用梯度计算后,从梯度值中提取的边缘仍然非常模糊。
        关于标准3,应该只对边缘有一个准确的响应。
        因此,非最大抑制可以帮助抑制除局部最大值之外的所有梯度值(通过将它们设置为0),
        其指示具有最强烈的强度值变化的位置。渐变图像中每个像素的算法是:
        1、将当前像素的边缘强度与正梯度方向和负梯度方向上的像素的边缘强度进行比较。
        2、如果当前像素的边缘强度与具有相同方向的掩模中的其他像素相比是最大的
        (即,指向y方向的像素,则将其与其上方和下方的像素进行比较,垂直轴),该值将被保留。否则,该值将被抑制。
        
        例如在行人检测中,滑动窗口经提取特征,经分类器分类识别后,每个窗口都会得到一个分数。
        但是滑动窗口会导致很多窗口与其他窗口存在包含或者大部分交叉的情况。
        这时就需要用到NMS来选取那些邻域里分数最高(是行人的概率最大),并且抑制那些分数低的窗口。'''

        with dt[2]:
            # 调用utils/general.py文件中的non_max_suppression类进行非极大值抑制
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
            # pred 是网络的输出结果
            # conf_thres 是置信度阈值
            # iou_thres 是iou阈值
            # classes 是否只保留指定类别
            # 进行nms是否也去除不同类别之间的框
            # max-det 保留的检测框最大数量


        # 后续保存或者打印预测信息
        # 对每张图片进行处理 将pred映射回原图
        for i, det in enumerate(pred):
            seen += 1
            if webcam:
                # 如果输入源是webcam则batch_size>=1 取出dataset中的一张图片
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                # 但是大部分情况都是从LoadImages流读取本地文件中的照片或者视频 所以batch_size=1
                # p:当前图片/视频的绝对路径
                # im0: 原始图片 (letterbox + pad 之前的图片)
                # frame: 视频流
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)   # 当前路径(yolov5/data/images)
            save_path = str(save_dir / p.name)    # 图片或视频的保存路径(如runs/detect/exp/***.jpg)
            # 设置保存框的坐标的txt文件路径,每张图片对应一个框坐标信息
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')
            # 设置打印图片的信息
            s += '%gx%g ' % im.shape[2:]
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # 归一化增益
            imc = im0.copy() if save_crop else im0   # 保存截图
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()    #将预测信息映射到原图

                # 打印检测到的类别数量
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)},"

                # 保存结果: txt/图片画框/crop-image
                for *xyxy, conf, cls in reversed(det):
                    # 将每个图片的预测信息分别存入save_dir/labels下的xxx.txt中 每行:class_id + score + xywh
                    if save_txt:
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1,4)) / gn).view(-1).tolist()
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)
                        with open(f'{txt_path}.txt', 'a') as f:
                            # .rstrip 去除后面的空白字符串
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    # 在原图上画框,将预测到的目标剪切出来 保存成图片 保存在save_dir/crops下 在原图像上画框或者保存结果
                    if save_img or save_crop or save_img:   # 在原图上画框
                        c = int(cls)    # 整数类
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    if save_crop:
                        # 在原图上画框 + 将预测到的目标剪切出来 保存成图片 保存在save_dir/crops下
                        save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)


            im0 = annotator.result()
            if view_img:
                if platform.system() == 'Linux' and p not in windows:
                    windows.append(p)
                    cv2.nameWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)

            # 保存图片
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path[i] != save_path:
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()
                        if vid_cap:
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix('.mp4'))  # 强制输出结果为.mp4后缀
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w,h))
                    vid_writer[i].write(im0)

        # 打印时间
        LOGGER.info(f"{s}{''if len(det) else '(no detection),'}{dt[1].dt * 1E3:.1f}ms")

    # 打印每张图片的检测速度
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])


def parse_opt():
    parser = argparse.ArgumentParser()     # 创建一个解析对象
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL')
    '''
    nargs:ArgumentParser对象通常将一个动作与一个命令行参数关联。nargs关键字参数将一个动作与不同数目的命令行参数关联在一起
    nargs=N,一个选项后可以跟多个参数,参数的个数必须为N的值,这些参数会生成一个类表,当nargs=1时,会生成一个长度为1的列表
    nargs=?,如果没有在命令行出现对应的项,则给对应的项赋值为default。特殊的是,对于可选项,如果命令行出现了此可选项,但是之后没有跟赋值参数
            则此时给此可选项赋cons的值,而不是赋default的值。
    nargs=*,和N类似,但是没有规定列表长度
    nargs=+,和*类似,但是给对应的项没有传入参数时,会报错error:too few arguments。
    nargs=argparse.REMAINDER,所有剩余的参数,均转化为一个列表赋值给此项,通常用此方法将剩余的参数传入另一个parser进行解析。如果nargs没有
            定义,则可传入的参数的数量由action决定,通常情况下为一个,并且不会生成长度为1的列表。
    '''
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detection per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or CPU')
    parser.add_argument('--view-img', action='store_true', help='show results')
    # action='store_true' 运行代码时在命令行加--view-img时,view-img为Ture,不加--view-img时,--view-img为False
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prodiction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visuslize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False,  action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidence')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    opt = parser.parse_args()     # 解析传入参数
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1    # 扩充维度,如果是一位就扩充一位
    print_args(vars(opt))
    return opt

def main(opt):
    # 检查环境、打印参数,主要是requrement.txt的包是否安装,用彩色显示设置参数
    check_requirements(exclude=('tensorboard', 'thop'))
    # 执行run()函数
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

猜你喜欢

转载自blog.csdn.net/weixin_45994963/article/details/128110819
今日推荐