pytorch 实现yolo3详细理解(三) 数据集处理

本章详细讲解数据的处理问题,将coco数据集读取,以及之后自定义数据集的处理,

数据预处理思想

yolo3的数据集处理也是一大亮点,由于yolo3对数据集的输入有要求,指定的照片输入大小必须是416,所有对于不满足照片的大小有一系列的操作,如果直接resize操作,将直接损失照片信息,网络在学习分类的过程还要适应照片尺寸的问题,导致训练效果不佳,在yolo3中是先进行高和宽的调整一样大,在进行上采样的resize,同时要修改label的坐标位置,随机水平翻转,再一次随机变化大小,之后再变化到416的大小尺寸作为输入。

代码

class ListDataset(Dataset):  #继承Dataset
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        with open(list_path, "r") as file:
            self.img_files = file.readlines()

        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") #这一步是生成labels的位置
            for path in self.img_files
        ]
        self.img_size = img_size
        self.max_objects = 100
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

    def __getitem__(self, index):

        # ---------
        #  Image
        # ---------

        img_path = self.img_files[index % len(self.img_files)].rstrip()   #按照索引的方式找到对应的路径

        # Extract image as PyTorch tensor
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))  #读取照片

        # Handle images with less than three channels
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            img = img.expand((3, img.shape[1:]))

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)   #直接理解为照片的宽度和高度
        # Pad to square resolution
        img, pad = pad_to_square(img, 0)   #这一步就是将高和宽变成一样大小
        _, padded_h, padded_w = img.shape

        # ---------
        #  Label
        # ---------

        label_path = self.label_files[index % len(self.img_files)].rstrip()   #照片对应的label路径

        targets = None
        if os.path.exists(label_path):
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # Extract coordinates for unpadded + unscaled image
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)    #label的坐标点位置是xywh所以先进行转化
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # Adjust for added padding
            x1 += pad[0]  #照片大小变化了所以框的坐标点需要修改
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]
            # Returns (x, y, w, h)
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w    #在次重新转化xywh形式
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h

            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

        # Apply augmentations
        if self.augment:
            if np.random.random() < 0.5:
                img, targets = horisontal_flip(img, targets)   #随机水平翻转

        return img_path, img, targets

    def collate_fn(self, batch):    #自定义类中的函数用于batch处理
        paths, imgs, targets = list(zip(*batch))   #可以理解写这个函数必须写这个操作,就是将 __getitem__的输出作为列表,
        # Remove empty placeholder targets
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        targets = torch.cat(targets, 0)   #增加一个维度,就可以是批次额处理
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))  #对照片随机变大变小
        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])  #在一次将照片大小变化为原来的416
        self.batch_count += 1
        return paths, imgs, targets

    def __len__(self):
        return len(self.img_files)

这一步将数据读取封装成一个类,中间还有一起其他的函数,

def pad_to_square(img, pad_value):
    c, h, w = img.shape
    dim_diff = np.abs(h - w)
    # (upper / left) padding and (lower / right) padding
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # Determine padding
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    # Add padding
    img = F.pad(img, pad, "constant", value=pad_value)

    return img, pad

这一步就是将高和宽整成一样大小,比如500,300输出就是500,500的大小,用pad1,pad2记录是高或者宽拉长了多少,用于框的位置修改。

def resize(image, size):
    image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    return image

通过上采样的形式就修改了照片的尺寸,比直接进行resize的效果要好

def random_resize(images, min_size=288, max_size=448):
    new_size = random.sample(list(range(min_size, max_size + 1, 32)), 1)[0]
    images = F.interpolate(images, size=new_size, mode="nearest")
    return images

这一步是将大小随机的变化,大小变化设置了一定范围

def horisontal_flip(images, targets):
    images = torch.flip(images, [-1])
    targets[:, 2] = 1 - targets[:, 2]
    return images, targets

进行水平翻转的代码

其中有一个重点是__getitem__的输出,理解输出值得形式,可以自己从新再写一个dataset类的读取,官方给出的代码很一般,后面我会自己写一个csv文件的读取
看下输出值,以一张照片为例


print(img_path)
print(img.shape)
print(targets)
/images/train2014/COCO_train2014_000000000009.jpg
torch.Size([3, 768, 768])
tensor([[ 0.0000, 45.0000,  0.4795,  0.6416,  0.9556,  0.4466],
        [ 0.0000, 45.0000,  0.7365,  0.3104,  0.4989,  0.3573],
        [ 0.0000, 50.0000,  0.6371,  0.6747,  0.4941,  0.3829],
        [ 0.0000, 45.0000,  0.3394,  0.4392,  0.6789,  0.5861],
        [ 0.0000, 49.0000,  0.6468,  0.2244,  0.1180,  0.0727],
        [ 0.0000, 49.0000,  0.7731,  0.2224,  0.0907,  0.0729],
        [ 0.0000, 49.0000,  0.6683,  0.2952,  0.1313,  0.1102],
        [ 0.0000, 49.0000,  0.6429,  0.1844,  0.1481,  0.1110]])

先看下yolo3训练时的数据集在文件夹内的放置
在这里插入图片描述
一个训练集下的数据集信息,images是照片,labels是每张照片对应的label信息,
classes是全部分类的名称,train保存训练图片的路径,valid是测试照片的路径
在这里插入图片描述

只有一张,照片名字对应label的名字
在这里插入图片描述这个是train的label 一张照片有一个txt文件保存信息,一个txt可能包含多种框,这种事为了读取一张照片就将所有的框作为处理,
在这里插入图片描述
label内的信息,储存方式为xywh方式,坐标点的位置进行归一化了
在这里插入图片描述
这个是训练的txt,用来保存全部需要训练的照片路径,通过读取这一个文件来加载照片

原创文章 25 获赞 35 访问量 5197

猜你喜欢

转载自blog.csdn.net/cp1314971/article/details/104993542