【数据处理】pth文件读取

1. 数据处理

首先将json文件(如下),经过一系列处理好保存在trainset.pth文件中

1.1 json文件数据预处理----trainset.pth文件

        self.path_trainset = osp.join(self.subdir_processed, 'trainset.pth') #将vqa2.0json文件处理好后存放的地方

    def process(self):
        dir_ann = osp.join(self.dir_raw, 'annotations')
        path_train_ann = osp.join(dir_ann, 'mscoco_train2014_annotations.json')
        path_train_ques = osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')
        train_ann = json.load(open(path_train_ann))
        train_ques = json.load(open(path_train_ques))
        trainset = self.merge_annotations_with_questions(train_ann, train_ques) #合并答案和question文件
        trainset = self.add_image_names(trainset) #向文件中添加图像名
        trainset['annotations'] = self.add_answer(trainset['annotations']) #向文件中添加答案
        trainset['annotations'] = self.tokenize_answers(trainset['annotations']) #对答案进行tokenize处理
        trainset['questions'] = self.tokenize_questions(trainset['questions'], self.nlp) #对问题采用nlp进行tokenize处理
        trainset['questions'] = self.insert_UNK_token(trainset['questions'], wcounts, self.minwcount)
        trainset['questions'] = self.encode_questions(trainset['questions'], word_to_wid)
        trainset['annotations'] = self.encode_answers(trainset['annotations'], ans_to_aid)
        torch.save(trainset, self.path_trainset) #保存处理好后的json文件到trainset.pth中
#加载数据集
if not os.path.exists(self.subdir_processed):
	self.process()      
self.dataset = torch.load(self.path_trainset)

在这里插入图片描述
’questions’
在这里插入图片描述
’annotations’
在这里插入图片描述

1.2 获取faster-rcnn提取好的图像特征信息

    #添加rcnn提取的信息
    def add_rcnn_to_item(self, item):
        '''
        :param item: 传入的coco/extract/coco_train*******.jpg.pth文件
        :return:
        '''
        path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
        item_rcnn = torch.load(path_rcnn) #加载pth文件
        print(item_rcnn)
        item['visual'] = item_rcnn['pooled_feat'] #区域特征
        item['coord'] = item_rcnn['rois'] #感兴趣区域位置
        item['norm_coord'] = item_rcnn['norm_rois'] #感兴趣区域特征标准化
        item['nb_regions'] = item['visual'].size(0) #区域数
        return item

在这里插入图片描述

1.3 向传入模型的数据中添加1.1处理好的trainset.pth信息和faster-rcnn提取好的图像特征信息

    def __getitem__(self, index):
        item = {
    
    }
        item['index'] = index

        # Process Question (word token)
        question = self.dataset['questions'][index]
        if self.load_original_annotation:
            item['original_question'] = question

        item['question_id'] = question['question_id'] #向item中添加问题id:question_id
        item['question'] = torch.LongTensor(question['question_wids']) #向item添加问题单词索引表示:question
        item['lengths'] = torch.LongTensor([len(question['question_wids'])]) #向item添加问题长度:lengths
        item['image_name'] = question['image_name'] #向item添加图像名:image_name

        # Process Object, Attribut and Relational features
        # 处理对象、特性和关系特征
        item = self.add_rcnn_to_item(item) #向item中添加由faster-rcnn提取好的图像特征信息 :boxes,feature

        # 如果答案存在,处理答案(主要是因为测试集没有答案,所有处理训练集)
        if 'annotations' in self.dataset:
            annotation = self.dataset['annotations'][index]
            if self.load_original_annotation:
                item['original_annotation'] = annotation
            if 'train' in self.split and self.samplingans:
                proba = annotation['answers_count']
                proba = proba / np.sum(proba)
                item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba))
            else:
                item['answer_id'] = annotation['answer_id']
            item['class_id'] = torch.LongTensor([item['answer_id']])
            item['answer'] = annotation['answer']
            item['question_type'] = annotation['question_type']
        else:
            if item['question_id'] in self.is_qid_testdev:
                item['is_testdev'] = True
            else:
                item['is_testdev'] = False
        return item

整个item字典中键有

{
index :索引,
question_id:问题id,458752001
question:问题, tensor([4321, 2932, 1997, 3968, 2286, 2878])
lengths:问题长度, tensor([6]), 
image_name:图像名,'COCO_train2014_000000458752.jpg'
visual:图像特征,
coord:感兴趣区域位置信息,
norm_coord:感兴趣区域位置信息标准化,
nb_regions:区域数,36
answer_id:答案id,382
class_id:分类id, tensor([382])
answer:答案,'pitcher'
question_type:问题类型  'what'
}

如下item数据具体信息:

{		'index': 1, 
		'question_id': 458752001,
		 'question': tensor([4321, 2932, 1997, 3968, 2286, 2878]), 'lengths': tensor([6]), 
		 'image_name': 'COCO_train2014_000000458752.jpg', 
		 'visual': tensor([[0.0000, 0.0000, 0.0231,  ..., 0.0000, 0.0281, 1.5262],
        [0.0000, 0.0169, 0.0587,  ..., 0.0000, 0.0064, 1.1313],
        [0.3978, 0.0000, 0.0000,  ..., 0.0000, 0.1113, 3.8770],
        ...,
        [0.0326, 0.0000, 0.0000,  ..., 0.0799, 2.7793, 1.2371],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.3857],
        [0.0084, 0.8026, 0.0966,  ..., 0.0000, 0.7668, 0.0798]]), 
        
        'coord': tensor([[282.9814, 302.8545, 372.2248, 468.0808],
        [311.6291, 333.7408, 359.6907, 358.5282],
        [215.0726, 172.1102, 352.8074, 407.3577],
        [285.7687, 189.5694, 329.6428, 231.0218],
        [274.9748, 160.3990, 318.7502, 208.2673],
        [241.7279, 302.4235, 286.6740, 342.9375],
        [241.9454, 230.9683, 355.9464, 350.7243],
        [  0.0000,   0.0000, 383.6760, 425.9977],
        [372.8926, 360.4357, 629.3776, 401.7158],
        [348.9116,  45.5669, 639.2000, 479.2000],
        [391.1129, 149.4865, 610.5438, 353.4078],
        [ 28.7101, 178.1590, 235.1222, 398.0370],
        [353.4412, 420.1210, 381.8891, 442.3383],
        [249.2581, 316.8993, 319.8651, 357.7612],
        [  0.0000,   0.0000, 487.8425, 174.9353],
        [177.7185,  63.3202, 472.4503, 479.2000],
        [  6.3949, 120.7137, 639.2000, 479.2000],
        [ 88.5412, 209.1507, 558.1020, 479.2000],
        [254.3652, 164.8777, 332.2064, 206.6262],
        [189.8386, 128.2660, 426.3345, 479.2000],
        [237.2819, 281.1407, 411.9520, 479.2000],
        [ 20.9822, 370.2453, 301.8062, 402.6493],
        [312.2184, 263.2010, 344.6071, 296.3635],
        [265.3174, 229.0582, 374.0845, 349.4183],
        [257.8582, 154.3860, 341.2274, 235.9603],
        [108.7576, 342.0231, 573.0241, 455.5504],
        [ 57.4191,   0.0000, 617.8732, 117.8669],
        [234.5487, 271.3556, 268.1855, 318.8475],
        [323.6842,   0.0000, 639.2000, 145.7849],
        [263.1414, 249.6308, 396.5386, 479.2000],
        [310.9734, 257.4800, 349.8676, 292.9267],
        [349.4448, 423.4623, 388.5869, 452.1093],
        [269.6038, 153.9579, 300.1459, 188.4087],
        [162.7299,   0.0000, 639.2000, 230.7880],
        [286.1820, 371.8325, 346.1609, 479.2000],
        [168.0096, 305.8445, 479.0755, 479.2000]]), 
        'norm_coord': tensor([[0.4422, 0.6309, 0.5816, 0.9752],
        [0.4869, 0.6953, 0.5620, 0.7469],
        [0.3361, 0.3586, 0.5513, 0.8487],
        [0.4465, 0.3949, 0.5151, 0.4813],
        [0.4296, 0.3342, 0.4980, 0.4339],
        [0.3777, 0.6300, 0.4479, 0.7145],
        [0.3780, 0.4812, 0.5562, 0.7307],
        [0.0000, 0.0000, 0.5995, 0.8875],
        [0.5826, 0.7509, 0.9834, 0.8369],
        [0.5452, 0.0949, 0.9988, 0.9983],
        [0.6111, 0.3114, 0.9540, 0.7363],
        [0.0449, 0.3712, 0.3674, 0.8292],
        [0.5523, 0.8753, 0.5967, 0.9215],
        [0.3895, 0.6602, 0.4998, 0.7453],
        [0.0000, 0.0000, 0.7623, 0.3644],
        [0.2777, 0.1319, 0.7382, 0.9983],
        [0.0100, 0.2515, 0.9988, 0.9983],
        [0.1383, 0.4357, 0.8720, 0.9983],
        [0.3974, 0.3435, 0.5191, 0.4305],
        [0.2966, 0.2672, 0.6661, 0.9983],
        [0.3708, 0.5857, 0.6437, 0.9983],
        [0.0328, 0.7713, 0.4716, 0.8389],
        [0.4878, 0.5483, 0.5384, 0.6174],
        [0.4146, 0.4772, 0.5845, 0.7280],
        [0.4029, 0.3216, 0.5332, 0.4916],
        [0.1699, 0.7125, 0.8954, 0.9491],
        [0.0897, 0.0000, 0.9654, 0.2456],
        [0.3665, 0.5653, 0.4190, 0.6643],
        [0.5058, 0.0000, 0.9988, 0.3037],
        [0.4112, 0.5201, 0.6196, 0.9983],
        [0.4859, 0.5364, 0.5467, 0.6103],
        [0.5460, 0.8822, 0.6072, 0.9419],
        [0.4213, 0.3207, 0.4690, 0.3925],
        [0.2543, 0.0000, 0.9988, 0.4808],
        [0.4472, 0.7747, 0.5409, 0.9983],
        [0.2625, 0.6372, 0.7486, 0.9983]]),
         'nb_regions': 36, 
         'answer_id': 382, 
         'class_id': tensor([382]),
          'answer': 'pitcher', 
          'question_type': 'what'}

猜你喜欢

转载自blog.csdn.net/snow_maple521/article/details/109637018