pytorch读取VOC数据集

简单介绍VOC数据集

首先介绍下VOC2007数据集(下图是VOC数据集格式,为了叙述方便,我这里只放了两张图像)
在这里插入图片描述
Main文件夹内的trainval.txt中的内容如下:存储了图像的名称不加后缀。

000009
000052

Annotations中存储的是标注文件,以xml文件存储。这里简单截个图说明一下:

<annotation>
   <folder>VOC2007</folder>
   <filename>000009.jpg</filename>               # !存储图像名称,若转换voc数据集,记住命名格式最好是相同
   <source>
      <database>The VOC2007 Database</database>
      <annotation>PASCAL VOC2007</annotation>
      <image>flickr</image>
      <flickrid>325443404</flickrid>
   </source>
   <owner>
      <flickrid>autox4u</flickrid>
      <name>Perry Aidelbaum</name>
   </owner>
   <size>                                          # 存储着图像的宽和高以及深度
      <width>500</width>                               
      <height>375</height>
      <depth>3</depth>
   </size>
   <segmented>0</segmented>
   <object>                                        # 存储着对象
      <name>horse</name>
      <pose>Right</pose>
      <truncated>0</truncated>
      <difficult>0</difficult>                   # 是否是困难检测目标
      <bndbox>                                     # 图像中的box
         <xmin>69</xmin>
         <ymin>172</ymin>
         <xmax>270</xmax>
         <ymax>330</ymax>
      </bndbox>
   </object>
</annotation>

pytorch读取脚本

然后你可以通过以下脚本进行读取–>(读取脚本在SSD中的方式也挺好,这里仅仅提供一种最mini的读取方式)。

import torch
import xml.etree.ElementTree as ET
import os
import cv2
import numpy as np
from torchvision import transforms

class VOCDataset(torch.utils.data.Dataset):

    CLASSES_NAME = (
        "__background__ ",                 # 记得加上背景类
        "pottedplant",
        "person",
        "horse",
    )
    # 初始化类
    def __init__(self, root_dir, resize_size=[800, 1024], split='trainval', use_difficult=False):

        self.root = root_dir
        self.use_difficult = use_difficult
        self.imgset = split

        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        # 读取trainval.txt中内容
        with open(self._imgsetpath % self.imgset) as f:     # % 是python字符串中的一个转义字符可以百度下,不难
            self.img_ids = f.readlines()
        self.img_ids = [x.strip() for x in self.img_ids]    # ['000009', '000052']

        self.name2id = dict(zip(VOCDataset.CLASSES_NAME, range(len(VOCDataset.CLASSES_NAME))))
        self.resize_size = resize_size
        self.mean = [0.485, 0.456, 0.406]      # voc数据集中所有图像矩阵的均值和方差,为后续图像归一化做准备
        self.std = [0.229, 0.224, 0.225]
        print("INFO=====>voc dataset init finished  ! !")

    def __len__(self):
        return len(self.img_ids)

    def _read_img_rgb(self, path):
        return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):

        img_id = self.img_ids[index]
        img = self._read_img_rgb(self._imgpath % img_id)

        anno  = ET.parse(self._annopath % img_id).getroot()  # 读取xml文档的根节点
        boxes = []
        classes = []

        for obj in anno.iter("object"):
            difficult = int(obj.find("difficult").text) == 1
            if not self.use_difficult and difficult:
                continue
            _box = obj.find("bndbox")
            box = [
                _box.find("xmin").text,
                _box.find("ymin").text,
                _box.find("xmax").text,
                _box.find("ymax").text,
            ]
            TO_REMOVE = 1                                  # 由于像素是网格存储,坐标2实质表示第一个像素格,所以-1
            box = tuple(
                map(lambda x: x - TO_REMOVE, list(map(float, box)))
            )
            boxes.append(box)

            name = obj.find("name").text.lower().strip()
            classes.append(self.name2id[name])             # 将类别映射回去

        boxes = np.array(boxes, dtype=np.float32)

        #将img,box和classes转成tensor
        img = transforms.ToTensor()(img)    # transforms 自动将 图像进行了归一化,
        boxes = torch.from_numpy(boxes)
        classes = torch.LongTensor(classes)

        return img, boxes, classes
if __name__ == '__main__':
    dataset = VOCDataset('E://Z_summary_net/Read_VOC/VOCdevkit/VOC2007/') # 实例化一个对象
    img,box,cls = dataset[0]          # 返回第一张图像及box和对应的类别
    print(img.shape)
    print(box)
    print(cls)

    # 这里简单做一下可视化
    # 由于opencv读入是矩阵,而img现在是tensor,因此,首先将tensor转成numpy.array
    img_ = (img.numpy()*255).astype(np.uint8).transpose(1,2,0)# 注意由于图像像素分布0-255,所以转成uint8
    print(img_.shape)
    cv2.imshow('test',img_)
    cv2.waitKey(0)

注意

这里有一个细节:torch将box、classes和img转换成张量tensor。假如现在我想利用OpenCV重新还原出图像来,那么存在一个问题:就是经过transforms.ToTensor()的img已经被进行了归一化,img的像素矩阵除以了255。因此,要想还原cv2中的numpy数组,需要*255。另外,由于cv2需要读取[W,H,C]的矩阵,因此,还需要通过transpose(1,2,0)交换一下维度。
最后,贴上transforms.ToTensor()的源码:

if isinstance(pic, np.ndarray):
    # handle numpy array
    if pic.ndim == 2:
        pic = pic[:, :, None]
    img = torch.from_numpy(pic.transpose((2, 0, 1)))     # 交换维度
    # backward compatibility
    if isinstance(img, torch.ByteTensor):                 # /255 归一化
        return img.float().div(255)
    else:
        return img

Guess you like

Origin blog.csdn.net/wulele2/article/details/109959299