A Bug Encountered When Loading a Triplet Dataset

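The version I started from was a fairly reliable implementation I found on GitHub, and I had previously used it to write and run triplet code on MNIST. When I pointed it at the CIFAR-10 dataset, however, it did not work and raised the error shown further down.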

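The error appeared to be a problem with the DataLoader iterator. I am not familiar with the low-level details, so my guess was a format issue. At the time I assumed the GitHub code itself was fine, since it had already been tested on MNIST. The error is raised in __getitem__, and the __getitem__ I had written at the time looked like this: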

def __getitem__(self, index):
    path1, path2, path3 = self.triplets[index]
    img1 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path1)]))
    img2 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path2)]))
    img3 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path3)]))
    if self.transform is not None:
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        img3 = self.transform(img3)

    return img1, img2, img3

After that, it kept failing inside the iterator:

Traceback (most recent call last):
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1689, in <module>
    main()
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1683, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1083, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/Retrieval/First-reproduction-mutlilabel/utils/DataProcessing.py", line 157, in <module>
    for batch_idx, triplet_train_data in enumerate(dset_triplet_train_loader, 0):#train_input, train_label, batch_ind 
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 322, in __next__
    return self._process_next_batch(batch)
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 357, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "E:\Retrieval\First-reproduction-mutlilabel\utils\DataProcessing.py", line 93, in __getitem__
    img3 = Image.open(os.path.join(self.data_path, self.img_filename_lists[path3 + 1]))
IndexError: list index out of range
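
The innermost frame of the traceback above is the list lookup self.img_filename_lists[path3 + 1] inside __getitem__, so one inexpensive check is whether every index stored in the triplets, plus any offset applied at lookup time, stays within the file list. The sketch below is only an illustration: find_out_of_range_triplets is a hypothetical helper that is not part of the original code, and it assumes the triplets are sequences of integer-like indices into a list of file names.

```python
def find_out_of_range_triplets(triplets, num_files, offset=0):
    """Return (position, triplet) pairs whose indices fall outside the file list.

    Hypothetical helper for illustration. `offset` mimics an adjustment
    applied at lookup time, such as the `+ 1` visible in the traceback.
    """
    bad = []
    for pos, triplet in enumerate(triplets):
        indices = [int(i) + offset for i in triplet]
        # Negative indices are flagged too: Python would not raise for them
        # but would silently wrap around to the end of the list.
        if any(i < 0 or i >= num_files for i in indices):
            bad.append((pos, tuple(triplet)))
    return bad


# Toy data: 5 files, so indices 0..4 are valid.
toy_triplets = [(0, 1, 4), (2, 3, 5), (4, 0, 1)]
print(find_out_of_range_triplets(toy_triplets, num_files=5))
print(find_out_of_range_triplets(toy_triplets, num_files=5, offset=1))
```

Running a check like this once over the whole triplet list, before building the DataLoader, would show whether the IndexError comes from the triplet file or the index offset rather than from the shape of the value returned by __getitem__.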
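
Searching Google for raise batch.exc_type(batch.exc_msg) turned up this post: https://discuss.pytorch.org/t/cannot-unsqueeze-empty-tensor/1300, which is somewhat related to the problem I eventually found. I suspected the cause was that __getitem__ returns three images, whereas a non-triplet network usually returns one image, one label, and one index, so some data-type error might occur when the loader is called. I did not suspect this very strongly, though, because I remained convinced, having run it on MNIST, that return img1, img2, img3 (three images with no label or index) was fine.

After a full day of debugging I still had not found the cause, so I gave in and turned to a triplet data-loading implementation on GitHub: https://github.com/CaptainEven/FaceRecognition. The rewritten triplet loading code looks like this: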

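from PIL import Image
from torch.utils.data import Dataset

# Sel is the author's module that defines select_triplet;
# its import is not shown in the original post.


class RamMakTrilet(Dataset):
    def __init__(self,
                 root,
                 img_name_file,
                 num_cls = 10,
                 num_triplets = 1000,
                 limit = 500,
                 transforms = None,
                 ):
        self.transforms = transforms
        # Pre-sample all triplets once, when the dataset is constructed
        self.triplets = [Sel.select_triplet(root,img_name_file, num_cls, limit)
                         for i in range(num_triplets)]

    def __len__(self):
        # Needed so that DataLoader knows how many triplets are available
        return len(self.triplets)

    def __getitem__(self, index):
        '''
        Return one triplet per call
        '''
        triplet = self.triplets[index]
        # print(triplet)
        # First three entries are image paths, last three are class labels
        data = [self.transforms(Image.open(img_path))
                for img_path in triplet[:3]]
        label = triplet[3:]
        # print(label)
        return data, label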

import os

import numpy as np


def select_triplet(dir, img_name_file, num_classes, limit=500, is_car=False):
    '''
    Randomly select one triplet from the image directory dir
    '''
    # np.random.seed(100)  # fix the random seed to make results easy to verify

    # Pick the image IDs of the anchor, positive, and negative samples
    anchor_cls = np.random.choice(num_classes)  # 0 .. num_classes - 1 (0~9 here)
    anchor_id = anchor_cls*limit + np.random.choice(limit)
    positive_id = anchor_cls*limit + \
        get_negative_id(limit, anchor_id - anchor_cls*limit)
    negative_cls = get_negative_id(num_classes, anchor_cls)  # randomly pick a negative class
    negative_id = negative_cls*limit + np.random.choice(limit)  # randomly pick a negative ID
    img_name = os.path.join(dir, img_name_file)
    # Build the paths of the anchor, positive, and negative images.
    # Note: the file list is re-read on every call.
    with open(img_name, 'r') as fp:
        img_filename = [x.strip() for x in fp]  # image file names, as a list
    anchor_path = os.path.join(dir, img_filename[anchor_id])
    positive_path = os.path.join(dir, img_filename[positive_id])
    negative_path = os.path.join(dir, img_filename[negative_id])

    # Return a 6-tuple
    return anchor_path, positive_path, negative_path, anchor_cls, anchor_cls, negative_cls
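
select_triplet relies on get_negative_id, which the post does not show. Judging only from how it is called (to pick a positive sample other than the anchor, and a negative class other than the anchor class), it presumably returns a random integer in [0, limit) that differs from the given one. The following is a guess at such a helper under that assumption, not the code from the referenced repository; the parameter name exclude_id is made up for this example, and the standard-library random module is used instead of NumPy to keep the sketch dependency-free.

```python
import random


def get_negative_id(limit, exclude_id):
    """Return a random integer in [0, limit) that is not exclude_id.

    Assumed behaviour, inferred from the call sites in select_triplet.
    """
    if limit < 2:
        raise ValueError("need at least two candidates to exclude one")
    # Draw from the limit - 1 allowed values, then skip over exclude_id.
    candidate = random.randrange(limit - 1)
    return candidate + 1 if candidate >= exclude_id else candidate


# Every draw stays in range and never equals the excluded value.
draws = [get_negative_id(10, 3) for _ in range(1000)]
print(min(draws), max(draws), 3 in draws)
```

Note also that the ID arithmetic in select_triplet (cls*limit + offset) only works if the file list is ordered by class, with limit consecutive entries per class.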

Finally, it is called in the main function as follows:
for batch_idx, triplet_train_data in enumerate(train_loader, 0):#train_input, train_label, batch_ind 
    data, _ = triplet_train_data
    data1, data2, data3 = data
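
The unpacking data1, data2, data3 = data works because __getitem__ returns the three images as a list, and the DataLoader's default collate function transposes a batch of such samples so that each position in the list becomes its own batch. The sketch below imitates only that transposition, with plain Python lists and strings standing in for tensors; toy_collate is a hypothetical stand-in, not PyTorch's implementation, which additionally stacks the leaves into tensors.

```python
def toy_collate(batch):
    """Imitate the transposing step of batching for nested sequences.

    Hypothetical stand-in: leaves are gathered into plain lists
    instead of being stacked into tensors.
    """
    first = batch[0]
    if isinstance(first, (list, tuple)):
        # Group the k-th element of every sample together, then recurse.
        return [toy_collate(list(group)) for group in zip(*batch)]
    return list(batch)


# Two samples shaped like the return value of __getitem__: (data, label),
# with strings standing in for image tensors.
samples = [
    (["a0", "p0", "n0"], (3, 3, 7)),
    (["a1", "p1", "n1"], (5, 5, 2)),
]
data, label = toy_collate(samples)
data1, data2, data3 = data
print(data1, data2, data3)
print(label)
```

Here data1 collects all anchors, data2 all positives, and data3 all negatives, which is the layout the training loop above expects.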

That solved the problem.

Reposted from blog.csdn.net/qq_33824968/article/details/84927329