Comer yolov8 bocado a bocado (2)

La primera parte de la capacitación se presentó anteriormente, que también es un artículo que la mayoría de la gente puede encontrar en Internet, pero las últimas dos partes deben ser información que no está disponible en Internet. Espero que les guste.

0. datos

Mis datos son algunos palés, principalmente para detectar los agujeros en los palés, e identificar la posición y el ángulo de deflexión de los agujeros. La imagen original es la siguiente
inserte la descripción de la imagen aquí
Mi anotación
inserte la descripción de la imagen aquí
Utilizo labelme para anotar y luego convertir al formato yolo, el código de conversión es el siguiente.

# coding=utf-8
import os
import sys

path = os.path.dirname(__file__)
sys.path.append(path)

'''
Author:Don
date:2022/8/3 11:49
desc:
'''
import os
import json
import glob
#输入口,就是你图片和json存放的那个文件,输出的txt也在这个文件夹里
labelme_dir=r"E:\2022\work\shchaiduo\image"


def get_labelme_data(labelme_dir):

    with open(labelme_dir) as f:
        j=json.load(f)
        out_data=[]
        img_h =j["imageHeight"]
        img_w =j["imageWidth"]
        for shape in j["shapes"]:
            label=shape["label"]
            points=shape["points"]
            x,y,x2,y2=points[0][0],points[0][1],points[1][0],points[1][1]
            x_c=(x+x2)//2
            y_c=(y+y2)//2
            w=abs(x-x2)
            h=abs(y-y2)
            out_data.append([label,x_c,y_c,w,h])
    return img_h,img_w,out_data

def rename_Suffix(in_,mode=".txt"):
    in_=in_.split('.')
    return  in_[0]+mode

def make_yolo_data(in_dir):
    json_list=glob.glob(os.path.join(in_dir,'*.json'))

    for json_ in json_list:
        json_path=os.path.join(in_dir,json_)
        json_txt=rename_Suffix(json_)
        img_h,img_w,labelme_datas=get_labelme_data(json_path)
        with open(os.path.join(in_dir,json_txt),'w+') as f:
            for labelme_data in labelme_datas:
                label=labelme_data[0]
                x_c=labelme_data[1]/img_w
                y_c=labelme_data[2]/img_h
                w=labelme_data[3]/img_w
                h=labelme_data[4]/img_h
                f.write("{} {} {} {} {}\n".format(label,x_c,y_c,w,h))
            f.close()


if __name__ == '__main__':
    make_yolo_data(labelme_dir)




inserte la descripción de la imagen aquí
las imágenes son imágenes
inserte la descripción de la imagen aquí

etiquetas es el formato txt de la etiqueta,
inserte la descripción de la imagen aquí
concretamente la siguiente figura, 0 es el identificador de la etiqueta, porque solo hay una clase, entonces la primera en mis datos es 0, y la segunda es la xywh correspondiente al agujero, pero debe ser dividido por el largo y el ancho de la imagen, específicamente Ver el código de conversión de etiquetas arriba. Debido a que solo hay 2 agujeros en una paleta, mi txt solo tiene 2 conjuntos de datos.
inserte la descripción de la imagen aquí

la prueba es una imagen
inserte la descripción de la imagen aquí

1. Preparación de datos antes del entrenamiento

Debido a que mis datos se recopilan en el sitio, no se necesitan muchas técnicas de mejora de datos (comprensión personal). En la industria, lo más importante es la seguridad más que la precisión. Significa que si es correcto, es 100%, y si es incorrecto, es 0%, lo mejor es que no haya detección falsa, y la detección perdida es aceptable. Por lo tanto, el modelo no sugiere una mejor capacidad de generalización. Es mejor llamar directamente a la policía por cosas que no se han visto, en lugar de dar un rango de detección aproximado. Así que solo usé la API de aLbumentations en v8 y eliminé todo lo demás. El tamaño de lote predeterminado = 1.
inserte la descripción de la imagen aquí

from pathlib import Path
import glob
import os
from torch.utils.data import Dataset
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
from PIL import Image, ImageOps
import random
import albumentations as A
import numpy as np
import torch

NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes


class Albumentations:
    # YOLOv8 Albumentations class (optional, only used if package is installed)
    def __init__(self, p=1.0):
        self.p = p
        T = [
            A.Blur(p=0.01),
            A.MedianBlur(p=0.01),
            A.ToGray(p=0.01),
            A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(p=0.0),
            A.RandomGamma(p=0.0),
            A.ImageCompression(quality_lower=75, p=0.0), ]  # transforms
        self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

    def __call__(self, labels):
        im = labels["img"]
        cls = labels["cls"]
        if len(cls):
            if self.transform and random.random() < self.p:
                new = self.transform(image=im, bboxes=labels["bboxes"], class_labels=cls)  # transformed
                labels["img"] = self._format_img(new["image"])
                labels["cls"] = torch.tensor(new["class_labels"])
                labels["bboxes"] = torch.tensor(new["bboxes"])
                labels["batch_idx"] = torch.zeros(labels["cls"].shape[0])
        return labels

    def _format_img(self, img):
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]).astype(np.float32)
        img = torch.from_numpy(img)
        return img


# 读取数据集存储
def verify_image_label(args):
    im_file, lb_file = args
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = im.size  # image size
        shape = (shape[1], shape[0])  # hw
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
        # verify labels
        if os.path.isfile(lb_file):
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
            else:
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            lb = np.zeros((0, 5), dtype=np.float32)
        lb = lb[:, :5]
        return im_file, lb, shape
    except Exception as e:
        return [None, None, None]


class YOLODataset(Dataset):

	def __init__(self, img_path, imgsz=640, augment=True):
        super(YOLODataset, self).__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.im_files = self.get_img_files(self.img_path)  # 读取图片
        self.labels = self.get_labels()  # 读取label
        self.ni = len(self.labels)
        # transforms
        self.transforms = Albumentations(p=1.0)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
        except Exception as e:
            raise FileNotFoundError(f"Error loading data from") from e
        return im_files

    def img2label_paths(self, img_paths):
        # Define label paths as a function of image paths
        sa, sb = f"{
      
      os.sep}images{
      
      os.sep}", f"{
      
      os.sep}labels{
      
      os.sep}"  # /images/, /labels/ substrings
        return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]

    def get_labels(self):
        self.label_files = self.img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops
        return cache["labels"]

    def cache_labels(self, path=Path("./labels.cache")):
        # Cache dataset labels, check images and read shapes
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        x = {
    
    "labels": []}
        desc = f"Scanning {
      
      path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files))  # im_file, lb, shape
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, in pbar:
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=None,
                            keypoints=None,
                            normalized=True,
                            bbox_format="xywh"))
            pbar.close()
        np.save(str(path), x)  # save cache for next time
        return x

2. Toma datos durante el entrenamiento

Para obtener datos, es necesario implementar las funciones len y getitem, porque se utiliza el conjunto de datos de torch. Como queremos reescribir el índice, reescribimos la función collate_fn

inserte la descripción de la imagen aquí

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        label = self.labels[index].copy()
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        return label

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, resized hw)
        f = self.im_files[i]
        im = cv2.imread(f)  # BGR
        if im is None:
            raise FileNotFoundError(f"Image Not Found {
      
      f}")
        h0, w0 = im.shape[:2]  # orig hw
        r = self.imgsz / max(h0, w0)  # ratio
        if r != 1:  # if sizes are not equal
            interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
            im = cv2.resize(im, (640, 512), interpolation=interp)
        return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
     
    @staticmethod
    def collate_fn(batch):
        new_batch = {
    
    }
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in ["bboxes", "cls"]:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch

3. Integra datos

def seed_worker(worker_id):
    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
img_path = "../datasets/kongdong/images"
dataset = YOLODataset(img_path=img_path, imgsz=640,  augment=True)
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True,
                          pin_memory=PIN_MEMORY,
                          collate_fn=getattr(dataset, "collate_fn", None),
                          worker_init_fn=seed_worker,
                          generator=generator)
pbar = tqdm(enumerate(train_loader), total=1, bar_format=TQDM_BAR_FORMAT)
for i, batch in pbar:

Tomamos el conjunto de datos en un bucle for, entonces, ¿qué hay en el lote? vamos a ver
inserte la descripción de la imagen aquí

Ahora vamos a comprobar si los datos son correctos después de la transformación.

#  检测输入的数据图像对不对
def check_data(batch):
    img = batch["img"]
    labels = batch['bboxes']  # xywh
    labels[:, 0] *= 640
    labels[:, 1] *= 512
    labels[:, 2] *= 640
    labels[:, 3] *= 512
    input_tensor = img.squeeze()
    # 从[0,1]转化为[0,255],再从CHW转为HWC,最后转为cv2
    input_tensor = input_tensor.permute(1, 2, 0).type(torch.uint8).numpy()
    # RGB转BRG
    input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)
    for box in labels.int():  # xywh
        cv2.rectangle(input_tensor, (int(box[0] - box[2] / 2), int(box[1] - box[3] / 2)),
                      (int(box[0] + box[2] / 2), int(box[1] + box[3] / 2)), (255, 0, 255), -1)
    cv2.imshow('img', input_tensor)
    cv2.waitKey(0)

for i, batch in pbar:
    # Forward
    with torch.cuda.amp.autocast(False):
        check_data(batch)
        img = batch["img"]
        preds = model(img)

ok, así es,
inserte la descripción de la imagen aquí
veamos si la salida del modelo es correcta
inserte la descripción de la imagen aquí
y tiene el mismo tamaño que la salida de la red de inferencia directa en nuestro primer artículo.

Supongo que te gusta

Origin blog.csdn.net/qq_33228039/article/details/129672141
Recomendado
Clasificación