最开始借鉴的版本使用的是github上找的一个比较靠谱的版本,并且以前在跑过关于mnist的triplet代码编写,但是在调用cifar10数据集是发现并不能用,报出以下错误
错误出现大概就是dataloader的迭代器问题,对于太底层的东西不是很清楚,大概觉得是格式问题,当时觉得github上给的代码应该是没问题的,因为毕竟已经用mnist的数据集测试过了,主要错误在__getitem__中,我当时写的_getitem__是这样的,
def __getitem__(self, index): path1, path2, path3 = self.triplets[index] img1 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path1)])) img2 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path2)])) img3 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path3)])) if self.transform is not None: img1 = self.transform(img1) img2 = self.transform(img2) img3 = self.transform(img3) return img1, img2, img3
后来总是在迭代器上报错:
Traceback (most recent call last):
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1689, in <module>
main()
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1683, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1083, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "E:/Retrieval/First-reproduction-mutlilabel/utils/DataProcessing.py", line 157, in <module>
for batch_idx, triplet_train_data in enumerate(dset_triplet_train_loader, 0):#train_input, train_label, batch_ind
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 322, in __next__
return self._process_next_batch(batch)
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 357, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "E:\Retrieval\First-reproduction-mutlilabel\utils\DataProcessing.py", line 93, in __getitem__
img3 = Image.open(os.path.join(self.data_path, self.img_filename_lists[path3 + 1]))
IndexError: list index out of range
根据raise batch.exc_type(batch.exc_msg) google搜索出这么一篇博客https://discuss.pytorch.org/t/cannot-unsqueeze-empty-tensor/1300,说的和最终发现的问题有点联系,我在怀疑应该是在__getitem__上return 三张图片 而一般不是triplet网络是返回的一张图片,一个标签,一个索引号,所以应该会在调用时 发生数据类型相关的错误,但是不是太怀疑的,因为我一直坚信使用前面mnist跑过 return img1, img2, img3,返回三张图片不加一个标签,一个索引号是没问题的,,最终查了一整天的问题还是没查出来,妥协了,参考了github上的一个triplet数据加载代码。https://github.com/CaptainEven/FaceRecognition 重改的Triplet数据加载代码,变成这样:
class RamMakTrilet(Dataset): def __init__(self, root, img_name_file, num_cls = 10, num_triplets = 1000, limit = 500, transforms = None, ): self.transforms = transforms self.triplets = [Sel.select_triplet(root,img_name_file, num_cls, limit) for i in range(num_triplets)] def __getitem__(self, index): ''' 每次返回一个triplet ''' triplet = self.triplets[index] # print(triplet) data = [self.transforms(Image.open(img_path)) for img_path in triplet[:3]] label = triplet[3:] # print(label) return data, label
def select_triplet(dir, img_name_file, num_classes, limit=500, is_car=False): ''' 从triplet_dir中随机选择一个三元组 ''' # np.random.seed(100) # 设置固定的随机数种子,便于验证 # 获取anchor, positive, negative 图片ID anchor_cls = np.random.choice(num_classes)#0~9 anchor_id = anchor_cls*limit + np.random.choice(limit) positive_id = anchor_cls*limit + \ get_negative_id(limit, anchor_id - anchor_cls*limit) negative_cls = get_negative_id(num_classes, anchor_cls) # 随机选择一个反例类型 negative_id = negative_cls*limit + np.random.choice(limit) # 随机选择一个反例ID img_name = os.path.join(dir, img_name_file) # 获取anchor, positive, negative 图片地址 fp = open(img_name, 'r') # 这里是open文件 img_filename = [x.strip() for x in fp] # 返回照片文件名 这个是list格式 fp.close() anchor_path = os.path.join(dir, img_filename[anchor_id]) positive_path = os.path.join(dir, img_filename[positive_id]) negative_path = os.path.join(dir, img_filename[negative_id]) # 返回6元组 return anchor_path, positive_path, negative_path, anchor_cls, anchor_cls, negative_cls 最终在主函数中调用如下:
for batch_idx, triplet_train_data in enumerate(train_loader, 0):#train_input, train_label, batch_ind data, _ = triplet_train_data data1, data2, data3 = data
完美解决问题