torch_13_自定义数据集

1.将图片的路径和标签写入csv文件并实现读取

 1  # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0
 2     def load_csv(self,filename):
 3         if not os.path.exists(os.path.join(self.root,filename)):
 4             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
 5             for name in self.name2label.keys():
 6                 # pokemeon\\mew\\0001.jpg mew可以通过字典查看其类别
 7                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
 8                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
 9             random.shuffle(images)
10             with open(os.path.join(self.root,filename),'w') as f:
11                 writer = csv.writer(f)
12                 for img in images:
13                     name = img.split(os.sep)
14                     label = self.name2label[name[-2]]
15                     writer.writerow([img,label])
16 
17          # 从csv中读取文件
18         images, labels = [], []
19         with open(os.path.join(self.root,filename),'r') as f:
20             reader = csv.reader(f)
21             for row in reader:
22                 img,label = row
23                 label = int(label)
24                 images.append(img)
25                 labels.append(label)
26         assert len(images) == len(labels) # 保证数据长度一致
       return images,labels

猜你喜欢

转载自www.cnblogs.com/shuangcao/p/11905505.html