from torch.utils import data
import os
import cv2
class datasest(data.Dataset):
    """Map-style dataset pairing images with rows of a bounding-box text file.

    Args:
        path1: Directory containing the image files.
        path2: Path to the annotation text file. Its first two lines are
            skipped (presumably the image count and the column header, as in
            CelebA's ``list_bbox_celeba.txt`` -- verify for other files).
            Every remaining line is split on runs of whitespace.

    ``self[index]`` returns ``(img, label)``:

    - ``label`` is the annotation row as a list of strings, e.g.
      ``['000001.jpg', '95', '71', '226', '313']``.
    - ``img`` is whatever ``cv2.imread`` returns for the file named in the
      row's first field. That is ``None`` if OpenCV cannot read the file.
    """

    def __init__(self, path1, path2):
        self.img_path = path1
        self.label_path = path2
        # Parsed annotation rows. Filled lazily on first item access, so
        # construction still performs no file I/O; reused by every lookup.
        self._labels = None

    def __len__(self):
        # NOTE(review): this counts directory entries, not annotation rows.
        # The two agree only when the image directory holds exactly one file
        # per annotated image.
        return len(os.listdir(self.img_path))

    def _read_labels(self):
        """Return the parsed annotation rows, reading the file only once."""
        if self._labels is None:
            # `with` guarantees the handle is closed (the original leaked it).
            with open(self.label_path, 'r', encoding='utf-8') as file:
                lines = file.readlines()
            # str.split() with no argument splits on any run of whitespace and
            # drops the newline, so column-alignment padding yields no empty
            # fields.
            self._labels = [line.split() for line in lines[2:]]
        return self._labels

    def __getitem__(self, index):
        label = self._read_labels()[index]
        img = cv2.imread(os.path.join(self.img_path, label[0]))
        return img, label
if __name__ == "__main__":
    # Preview every annotated box: draw it in red on its image and show it
    # briefly. The paths point at the author's local CelebA copy; adjust them.
    # The instance is named `dataset` (not `data`) so it does not shadow the
    # `torch.utils.data` module imported at the top of the file.
    dataset = datasest(r'D:\CelebA\img_celeba', r'D:\CelebA\list_bbox_celeba.txt')
    for i in range(len(dataset)):
        # Fetch the item once per iteration; each access reads the image from
        # disk, and the original indexed the dataset five times per loop.
        img, label = dataset[i]
        if img is None:
            # cv2.imread returns None for missing or unreadable files, which
            # would otherwise crash cv2.rectangle below.
            print(f'skipping unreadable image: {label[0]}')
            continue
        # Row layout: [file name, x, y, width, height].
        x1, y1, w, h = (int(v) for v in label[1:5])
        cv2.rectangle(img, (x1, y1), (x1 + w, y1 + h), (0, 0, 255), 2)
        cv2.imshow('i', img)
        cv2.waitKey(1)
    cv2.destroyAllWindows()
PyTorch Dataset：获取自己的数据集
猜你喜欢
转载自blog.csdn.net/weixin_38241876/article/details/91490046
今日推荐
周排行