1.将图片的路径和标签写入csv文件并实现读取
1 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 2 def load_csv(self,filename): 3 if not os.path.exists(os.path.join(self.root,filename)): 4 images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断 5 for name in self.name2label.keys(): 6 # pokemeon\\mew\\0001.jpg mew可以通过字典查看其类别 7 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径 8 images += glob.glob(os.path.join(self.root,name,'*.jpg')) 9 random.shuffle(images) 10 with open(os.path.join(self.root,filename),'w') as f: 11 writer = csv.writer(f) 12 for img in images: 13 name = img.split(os.sep) 14 label = self.name2label[name[-2]] 15 writer.writerow([img,label]) 16 17 # 从csv中读取文件 18 images, labels = [], [] 19 with open(os.path.join(self.root,filename),'r') as f: 20 reader = csv.reader(f) 21 for row in reader: 22 img,label = row 23 label = int(label) 24 images.append(img) 25 labels.append(label) 26 assert len(images) == len(labels) # 保证数据长度一致
return images,labels