PyTorch学习(二十三)—torchvision加载VOC数据集

11478104-6ae7a34f16bf856f.png
image.png
11478104-2111aa1ed641df55.png
image.png
11478104-7a6ce52f0fe4acc0.png
image.png
11478104-1615d6b5273023d8.png
image.png
11478104-bdd10bace0f1f84a.png
image.png
11478104-cf78960ea0ece9d8.png
image.png
11478104-954dc76b85a87472.png
image.png
11478104-66da0e90cf7a01be.png
image.png
  • 代码
import torchvision.datasets as datasets
import torchvision
import torch
import numpy as np
import cv2


def show_object_rect(image: np.ndarray, bndbox):
    """Draw one bounding-box rectangle on ``image`` and return the result.

    Args:
        image: Image array to draw on. It is handed to ``cv2.rectangle``
            without being copied, so the caller's array is drawn on in place.
        bndbox: Sequence ``(x_min, y_min, x_max, y_max)`` in pixel
            coordinates; the first two values are the top-left corner and
            the remaining values the bottom-right corner.

    Returns:
        The value returned by ``cv2.rectangle`` (the annotated image).
    """
    # Coerce the corners to tuples of plain ints so that lists, floats or
    # NumPy scalars are accepted no matter how strict the installed OpenCV
    # bindings are about point arguments.
    top_left = tuple(int(v) for v in bndbox[:2])
    bottom_right = tuple(int(v) for v in bndbox[2:])
    # Colour (0, 255, 255) with a 2-pixel line thickness.
    return cv2.rectangle(image, top_left, bottom_right, (0, 255, 255), 2)


def show_object_name(image: np.ndarray, name: str, p_tl):
    """Write the label ``name`` onto ``image`` and return the result.

    Args:
        image: Image array passed to ``cv2.putText`` as the drawing target.
        name: Text to render.
        p_tl: ``(x, y)`` point passed to ``cv2.putText`` as the text origin.

    Returns:
        The value returned by ``cv2.putText`` (the annotated image).
    """
    font_face = cv2.FONT_HERSHEY_PLAIN  # same value as the literal 1
    font_scale = 1
    text_color = (255, 0, 0)
    return cv2.putText(image, name, p_tl, font_face, font_scale, text_color)


# Open the VOC2012 detection "train" split from a local copy (no download).
voc_trainset = datasets.VOCDetection(
    '/media/weipenghui/Extra/VOC/VOC_Detection_2012',
    year='2012',
    image_set='train',
    download=False,
)

print('-' * 40)
# The label names the split that is actually loaded (image_set='train').
print('VOC2012-train')
print(len(voc_trainset))

# Key code returned by cv2.waitKey for the ESC key; used to stop browsing.
_ESC_KEY = 27

# Show every sample with its annotated boxes and class names, one image per
# key press.
for i, sample in enumerate(voc_trainset, 1):
    image, annotation = sample[0], sample[1]['annotation']
    objects = annotation['object']
    # The 'object' entry is not always a list (a single object can arrive as
    # a bare dict). Normalise it first so the count below is the number of
    # objects rather than the number of keys in that dict, and so one loop
    # handles both cases.
    if not isinstance(objects, list):
        objects = [objects]
    print('{} object:{}'.format(i, len(objects)))

    # VOCDetection yields an RGB PIL image, while cv2.imshow interprets the
    # array as BGR; convert so the photo is displayed with correct colours.
    show_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    for obj in objects:
        bndbox = obj['bndbox']
        # Coordinates are stored as strings in the parsed annotation.
        x_min = int(bndbox['xmin'])
        y_min = int(bndbox['ymin'])
        x_max = int(bndbox['xmax'])
        y_max = int(bndbox['ymax'])
        show_image = show_object_rect(show_image, (x_min, y_min, x_max, y_max))
        show_image = show_object_name(show_image, obj['name'], (x_min, y_min))

    cv2.imshow('image', show_image)
    # Any key advances to the next image; ESC ends the loop early so the
    # whole dataset does not have to be stepped through.
    if cv2.waitKey(0) & 0xFF == _ESC_KEY:
        break

cv2.destroyAllWindows()

print(voc_trainset)
print('Down load ok')

猜你喜欢

转载自blog.csdn.net/weixin_34416649/article/details/87130580
今日推荐