ultralytics的YOLOv8改为自用版本

由于需要用 PyQt 给 YOLOv8 做一个界面,而 ultralytics 的代码层层封装,直接调用不太方便,所以对其推理部分的源码进行了精简,具体代码放在下面。使用的 ultralytics 版本是 8.0.54;较新版本的 ultralytics 调整了模块路径(例如 `ultralytics.yolo.*` 下的模块),直接运行可能会出现导入错误。

具体代码如下。使用前需要根据自己的情况修改模型权重文件和数据集(data)配置文件的路径,即代码中 `model = ''`(权重文件路径)和 `data = ''`(data 配置文件路径)这两行。

import torch
import cv2
import numpy as np

from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
from ultralytics.nn.autobackend import AutoBackend


def get_annotator(img):
    """Return an ultralytics ``Annotator`` that draws on *img* with 3-px lines.

    The string form of the global ``model.names`` is passed as the
    ``example`` text (presumably so the Annotator can pick a renderer that
    suits the label characters -- confirm against the ultralytics version).
    """
    sample_text = str(model.names)
    return Annotator(img, example=sample_text, line_width=3)

def preprocess(img):
    """Convert an image batch to a float tensor on ``model.device`` scaled by 1/255.

    Args:
        img: numpy array or ``torch.Tensor``; assumed to hold 0-255 pixel
            values.

    Returns:
        A float32 tensor on the global model's device with values divided by 255.

    NOTE: the division is in place. If *img* is already a float tensor on
    ``model.device``, ``.to()``/``.float()`` return that same tensor and the
    caller's tensor is scaled too.
    """
    if not isinstance(img, torch.Tensor):
        img = torch.from_numpy(img)
    batch = img.to(model.device).float()
    batch /= 255  # 0 - 255 to 0.0 - 1.0
    return batch

def postprocess(preds, img, orig_imgs, conf_thres=None, iou_thres=None):
    """Run NMS on raw model output and wrap each image's detections in ``Results``.

    Args:
        preds: raw output of ``model(img)``.
        img: the preprocessed batch that was fed to the model; boxes are
            scaled *from* its spatial size ``img.shape[2:]``.
        orig_imgs: the original image, or a list with one image per batch
            entry; boxes are scaled *to* its shape and it is stored in the
            returned ``Results``.
        conf_thres: NMS confidence threshold. ``None`` (default) uses the
            module-level ``conf``.
        iou_thres: NMS IoU threshold. ``None`` (default) uses the module-level
            ``iou``.

    Returns:
        list of ``Results``, one per batch entry.
    """
    # Resolve the defaults at call time so later changes to the module-level
    # settings are picked up.
    conf_thres = conf if conf_thres is None else conf_thres
    iou_thres = iou if iou_thres is None else iou_thres
    preds = ops.non_max_suppression(preds,
                                    conf_thres,
                                    iou_thres,
                                    agnostic=False,
                                    max_det=300,
                                    classes=None)

    results = []
    for i, pred in enumerate(preds):
        # Pick the original image that matches this batch entry.
        orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
        # Map boxes from letterboxed-input coordinates back onto the original
        # image; tensor originals are left untouched.
        if not isinstance(orig_imgs, torch.Tensor):
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        # Live frames have no source file; keep the placeholder path 0.
        results.append(Results(orig_img=orig_img, path=0, names=model.names, boxes=pred))
    return results


_vid_writers = {}  # save_path -> cv2.VideoWriter, reused across save_preds calls


def save_preds(vid_cap, im0, save_path='1.mp4'):
    """Append one annotated frame to an MP4 file.

    The writer for *save_path* is created on the first call and reused
    afterwards. The previous version opened a new writer per call, which
    truncated the file each frame and never released it. The writer is
    released automatically at interpreter exit.

    Args:
        vid_cap: the ``cv2.VideoCapture`` frames come from; only its FPS is
            read, on the first call.
        im0: BGR frame to write. uint8 frames are written as-is. Other dtypes
            are treated as floats in [0, 1], clipped and scaled to uint8.
        save_path: output file; defaults to ``'1.mp4'`` as before.
    """
    writer = _vid_writers.get(save_path)
    if writer is None:
        import atexit

        # Integer required, floats produce error in MP4 codec. Webcams often
        # report 0 FPS, which would make the writer unusable, so fall back to 30.
        fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) or 30
        # OpenCV expects every written frame to match the size given here, so
        # take it from the frame itself rather than the capture properties.
        h, w = im0.shape[:2]
        writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        atexit.register(writer.release)  # finalize the file on exit
        _vid_writers[save_path] = writer
    if im0.dtype != np.uint8:
        # Only float frames need rescaling. Multiplying a uint8 frame by 255
        # overflows and corrupts the pixels.
        im0 = (np.clip(im0, 0, 1) * 255).astype(np.uint8)
    writer.write(im0)


# ---- Configuration: set the two paths below before running ----
model = ''  # path to the weights file; replaced by the loaded AutoBackend below
data = ''   # path to the dataset config .yaml, passed to AutoBackend as `data`
imgsz = 640        # letterbox target size
visualize = False  # forwarded to the model call
conf = 0.25        # NMS confidence threshold, read as a global by postprocess()
iou = 0.5          # NMS IoU threshold, read as a global by postprocess()

# Prefer the first GPU, but fall back to the CPU on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = AutoBackend(model,
                    device=device,
                    data=data,
                    verbose=True)
model.eval()

letterbox = LetterBox(imgsz, True, stride=32)  # built once, reused for every frame
cam = cv2.VideoCapture(0)
if not cam.isOpened():
    raise RuntimeError('Could not open camera 0')

try:
    # Inference only: don't record autograd graphs for every frame.
    with torch.no_grad():
        while True:
            ok, im0 = cam.read()
            if not ok:  # no frame (camera unplugged / stream ended)
                break

            im = np.stack([letterbox(image=im0)])
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous
            dt = (ops.Profile(), ops.Profile(), ops.Profile())

            # preprocess
            with dt[0]:
                im = preprocess(im)
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

            # inference
            with dt[1]:
                preds = model(im, augment=False, visualize=visualize)

            # postprocess
            with dt[2]:
                results = postprocess(preds, im, im0)

            det = results[0].boxes

            # Line/font thickness scale with the original frame, not the
            # letterboxed tensor; they are the same for every box in a frame.
            lw = max(round(sum(im0.shape) / 2 * 0.003), 2)
            tf = max(lw - 1, 1)  # font thickness

            # draw
            for box in reversed(det):
                # Do not name these `conf`/`id`. `conf` is the global NMS
                # threshold read by postprocess(), and overwriting it made
                # every later frame filter at the last box's score.
                cls_id = int(box.cls.squeeze())
                score = float(box.conf.squeeze())
                track_id = None if box.id is None else int(box.id.item())

                name = ('' if track_id is None else f'id:{track_id} ') + model.names[cls_id]
                label = f'{name} {score:.2f}'
                x1, y1, x2, y2 = map(int, box.xyxy.squeeze().tolist())
                cv2.rectangle(im0, (x1, y1), (x2, y2), colors(cls_id, True),
                              thickness=lw, lineType=cv2.LINE_AA)

                text_h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0][1]
                # Put the label above the box when there is room, else inside it.
                outside = y1 - text_h >= 3
                text_org = (x1, y1 - 2 if outside else y1 + text_h + 2)
                cv2.putText(im0,
                            label,
                            text_org,
                            0,
                            lw / 3,
                            (0, 0, 255),
                            thickness=tf,
                            lineType=cv2.LINE_AA)

            cv2.imshow("result", im0)
            # save_preds(cam, im0)  # uncomment to also record the annotated frames
            print('preprocess:{},inference:{},postprocess:{}'.format(dt[0].dt * 1E3, dt[1].dt * 1E3, dt[2].dt * 1E3))
            if cv2.waitKey(1) & 0xff == ord('q'):  # 1 millisecond
                break
finally:
    # Always free the camera and windows, even on errors or Ctrl+C.
    cam.release()
    cv2.destroyAllWindows()

猜你喜欢

转载自blog.csdn.net/w1036427372/article/details/130784437