This article introduces how .pth files are produced and read in a data-processing pipeline, using VQA 2.0 preprocessing as the running example; hopefully it offers some useful reference to developers facing the same problem.
1. Data Processing
First, the raw JSON files (shown below) are run through a series of processing steps and saved into a trainset.pth file.
1.1 Preprocessing the JSON data into the trainset.pth file
self.path_trainset = osp.join(self.subdir_processed, 'trainset.pth')  # where the processed VQA 2.0 JSON data is stored

def process(self):
    dir_ann = osp.join(self.dir_raw, 'annotations')
    path_train_ann = osp.join(dir_ann, 'mscoco_train2014_annotations.json')
    path_train_ques = osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')
    train_ann = json.load(open(path_train_ann))
    train_ques = json.load(open(path_train_ques))
    trainset = self.merge_annotations_with_questions(train_ann, train_ques)  # merge the answer and question files
    trainset = self.add_image_names(trainset)  # add the image names
    trainset['annotations'] = self.add_answer(trainset['annotations'])  # add the answers
    trainset['annotations'] = self.tokenize_answers(trainset['annotations'])  # tokenize the answers
    trainset['questions'] = self.tokenize_questions(trainset['questions'], self.nlp)  # tokenize the questions with the nlp tool
    # wcounts, word_to_wid and ans_to_aid are vocabularies built earlier in process(); omitted in this excerpt
    trainset['questions'] = self.insert_UNK_token(trainset['questions'], wcounts, self.minwcount)
    trainset['questions'] = self.encode_questions(trainset['questions'], word_to_wid)
    trainset['annotations'] = self.encode_answers(trainset['annotations'], ans_to_aid)
    torch.save(trainset, self.path_trainset)  # save the processed dataset to trainset.pth
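For context, a .pth file saved this way is just a Python object serialized with torch.save (pickle under the hood), so the whole nested dict round-trips unchanged. A minimal sketch with a made-up stand-in dict, not the real trainset:

import torch

# made-up stand-in for the processed trainset dict
trainset = {
    'questions': [{'question_id': 458752001, 'question_wids': [4321, 2932, 1997]}],
    'annotations': [{'answer_id': 382, 'answer': 'pitcher'}],
}
torch.save(trainset, 'trainset.pth')  # serializes the whole dict to disk

loaded = torch.load('trainset.pth')   # restores an identical Python object
print(loaded['questions'][0]['question_wids'])  # [4321, 2932, 1997]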
# load the dataset, running process() first if the processed files do not exist yet
if not os.path.exists(self.subdir_processed):
    self.process()
self.dataset = torch.load(self.path_trainset)
The loaded self.dataset dictionary has two keys:
'questions'
'annotations'
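A quick way to sanity-check the saved file is to load it by hand and inspect those keys; a minimal sketch (the path is an assumption, substitute your own subdir_processed):

import torch

dataset = torch.load('processed/trainset.pth')  # hypothetical path
print(dataset.keys())             # dict_keys(['questions', 'annotations'])
print(len(dataset['questions']))  # number of question entries
print(dataset['questions'][0])    # first entry: question_id, question_wids, image_name, ...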
1.2 Loading the image features extracted by Faster R-CNN
# add the features extracted by Faster R-CNN
def add_rcnn_to_item(self, item):
    '''
    :param item: item whose features are loaded from the coco/extract/coco_train*******.jpg.pth file
    :return:
    '''
    path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
    item_rcnn = torch.load(path_rcnn)  # load the .pth file
    print(item_rcnn)
    item['visual'] = item_rcnn['pooled_feat']    # region features
    item['coord'] = item_rcnn['rois']            # region-of-interest coordinates
    item['norm_coord'] = item_rcnn['norm_rois']  # normalized region-of-interest coordinates
    item['nb_regions'] = item['visual'].size(0)  # number of regions
    return item
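To see what one of these per-image feature files contains, you can load it directly; a minimal sketch (the path follows the pattern above, and the 36 regions and 2048-dim feature size are assumptions based on the example dump further down):

import torch

path = 'coco/extract/COCO_train2014_000000458752.jpg.pth'  # hypothetical dir_rcnn
item_rcnn = torch.load(path)
print(item_rcnn.keys())                # 'pooled_feat', 'rois', 'norm_rois', ...
print(item_rcnn['pooled_feat'].shape)  # e.g. torch.Size([36, 2048]): one vector per region
print(item_rcnn['rois'].shape)         # e.g. torch.Size([36, 4]): (x1, y1, x2, y2) boxes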
1.3 Adding the trainset.pth information from 1.1 and the Faster R-CNN image features from 1.2 to the data passed to the model
def __getitem__(self, index):
    item = {}
    item['index'] = index
    # Process Question (word token)
    question = self.dataset['questions'][index]
    if self.load_original_annotation:
        item['original_question'] = question
    item['question_id'] = question['question_id']  # add the question id to item: question_id
    item['question'] = torch.LongTensor(question['question_wids'])  # add the question as word indices: question
    item['lengths'] = torch.LongTensor([len(question['question_wids'])])  # add the question length: lengths
    item['image_name'] = question['image_name']  # add the image name: image_name
    # Process Object, Attribute and Relational features
    item = self.add_rcnn_to_item(item)  # add the image features extracted by Faster R-CNN: boxes, features
    # Process the answer if it exists (the test set has no answers, so this only applies to the training data)
    if 'annotations' in self.dataset:
        annotation = self.dataset['annotations'][index]
        if self.load_original_annotation:
            item['original_annotation'] = annotation
        if 'train' in self.split and self.samplingans:
            # sample an answer id in proportion to how often each answer was given
            proba = annotation['answers_count']
            proba = proba / np.sum(proba)
            item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba))
        else:
            item['answer_id'] = annotation['answer_id']
        item['class_id'] = torch.LongTensor([item['answer_id']])
        item['answer'] = annotation['answer']
        item['question_type'] = annotation['question_type']
    else:
        if item['question_id'] in self.is_qid_testdev:
            item['is_testdev'] = True
        else:
            item['is_testdev'] = False
    return item
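Because each item mixes variable-length tensors (the question) with Python scalars and strings, feeding these items into a DataLoader usually requires a custom collate_fn. The sketch below is a hypothetical minimal version, not the repository's actual collate function; it pads questions to the longest one in the batch and assumes the number of regions is fixed (36, as in the dump below):

import torch
from torch.utils.data import DataLoader

def collate_items(batch):
    # pad the variable-length question tensors to the batch maximum
    max_len = max(item['lengths'].item() for item in batch)
    questions = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, item in enumerate(batch):
        q = item['question']
        questions[i, :q.size(0)] = q
    return {
        'question': questions,                                        # (batch, max_len)
        'lengths': torch.cat([item['lengths'] for item in batch]),    # (batch,)
        'visual': torch.stack([item['visual'] for item in batch]),    # (batch, 36, feat_dim); needs fixed nb_regions
        'class_id': torch.cat([item['class_id'] for item in batch]),  # (batch,)
        'image_name': [item['image_name'] for item in batch],
    }

# 'dataset' is an instance of the Dataset class defined above
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_items)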
The keys of the full item dictionary are:
{
index: the sample index,
question_id: question id, 458752001
question: the question as word indices, tensor([4321, 2932, 1997, 3968, 2286, 2878])
lengths: question length, tensor([6]),
image_name: image name, 'COCO_train2014_000000458752.jpg'
visual: image region features,
coord: region-of-interest coordinates,
norm_coord: normalized region-of-interest coordinates,
nb_regions: number of regions, 36
answer_id: answer id, 382
class_id: class id, tensor([382])
answer: the answer, 'pitcher'
question_type: question type, 'what'
}
The concrete contents of one item look like this:
{ 'index': 1,
  'question_id': 458752001,
  'question': tensor([4321, 2932, 1997, 3968, 2286, 2878]),
  'lengths': tensor([6]),
  'image_name': 'COCO_train2014_000000458752.jpg',
  'visual': tensor([[0.0000, 0.0000, 0.0231,  ..., 0.0000, 0.0281, 1.5262],
          [0.0000, 0.0169, 0.0587,  ..., 0.0000, 0.0064, 1.1313],
          [0.3978, 0.0000, 0.0000,  ..., 0.0000, 0.1113, 3.8770],
          ...,
          [0.0326, 0.0000, 0.0000,  ..., 0.0799, 2.7793, 1.2371],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.3857],
          [0.0084, 0.8026, 0.0966,  ..., 0.0000, 0.7668, 0.0798]]),
  'coord': tensor([[282.9814, 302.8545, 372.2248, 468.0808],
          [311.6291, 333.7408, 359.6907, 358.5282],
          [215.0726, 172.1102, 352.8074, 407.3577],
          [285.7687, 189.5694, 329.6428, 231.0218],
          [274.9748, 160.3990, 318.7502, 208.2673],
          [241.7279, 302.4235, 286.6740, 342.9375],
          [241.9454, 230.9683, 355.9464, 350.7243],
          [  0.0000,   0.0000, 383.6760, 425.9977],
          [372.8926, 360.4357, 629.3776, 401.7158],
          [348.9116,  45.5669, 639.2000, 479.2000],
          [391.1129, 149.4865, 610.5438, 353.4078],
          [ 28.7101, 178.1590, 235.1222, 398.0370],
          [353.4412, 420.1210, 381.8891, 442.3383],
          [249.2581, 316.8993, 319.8651, 357.7612],
          [  0.0000,   0.0000, 487.8425, 174.9353],
          [177.7185,  63.3202, 472.4503, 479.2000],
          [  6.3949, 120.7137, 639.2000, 479.2000],
          [ 88.5412, 209.1507, 558.1020, 479.2000],
          [254.3652, 164.8777, 332.2064, 206.6262],
          [189.8386, 128.2660, 426.3345, 479.2000],
          [237.2819, 281.1407, 411.9520, 479.2000],
          [ 20.9822, 370.2453, 301.8062, 402.6493],
          [312.2184, 263.2010, 344.6071, 296.3635],
          [265.3174, 229.0582, 374.0845, 349.4183],
          [257.8582, 154.3860, 341.2274, 235.9603],
          [108.7576, 342.0231, 573.0241, 455.5504],
          [ 57.4191,   0.0000, 617.8732, 117.8669],
          [234.5487, 271.3556, 268.1855, 318.8475],
          [323.6842,   0.0000, 639.2000, 145.7849],
          [263.1414, 249.6308, 396.5386, 479.2000],
          [310.9734, 257.4800, 349.8676, 292.9267],
          [349.4448, 423.4623, 388.5869, 452.1093],
          [269.6038, 153.9579, 300.1459, 188.4087],
          [162.7299,   0.0000, 639.2000, 230.7880],
          [286.1820, 371.8325, 346.1609, 479.2000],
          [168.0096, 305.8445, 479.0755, 479.2000]]),
  'norm_coord': tensor([[0.4422, 0.6309, 0.5816, 0.9752],
          [0.4869, 0.6953, 0.5620, 0.7469],
          [0.3361, 0.3586, 0.5513, 0.8487],
          [0.4465, 0.3949, 0.5151, 0.4813],
          [0.4296, 0.3342, 0.4980, 0.4339],
          [0.3777, 0.6300, 0.4479, 0.7145],
          [0.3780, 0.4812, 0.5562, 0.7307],
          [0.0000, 0.0000, 0.5995, 0.8875],
          [0.5826, 0.7509, 0.9834, 0.8369],
          [0.5452, 0.0949, 0.9988, 0.9983],
          [0.6111, 0.3114, 0.9540, 0.7363],
          [0.0449, 0.3712, 0.3674, 0.8292],
          [0.5523, 0.8753, 0.5967, 0.9215],
          [0.3895, 0.6602, 0.4998, 0.7453],
          [0.0000, 0.0000, 0.7623, 0.3644],
          [0.2777, 0.1319, 0.7382, 0.9983],
          [0.0100, 0.2515, 0.9988, 0.9983],
          [0.1383, 0.4357, 0.8720, 0.9983],
          [0.3974, 0.3435, 0.5191, 0.4305],
          [0.2966, 0.2672, 0.6661, 0.9983],
          [0.3708, 0.5857, 0.6437, 0.9983],
          [0.0328, 0.7713, 0.4716, 0.8389],
          [0.4878, 0.5483, 0.5384, 0.6174],
          [0.4146, 0.4772, 0.5845, 0.7280],
          [0.4029, 0.3216, 0.5332, 0.4916],
          [0.1699, 0.7125, 0.8954, 0.9491],
          [0.0897, 0.0000, 0.9654, 0.2456],
          [0.3665, 0.5653, 0.4190, 0.6643],
          [0.5058, 0.0000, 0.9988, 0.3037],
          [0.4112, 0.5201, 0.6196, 0.9983],
          [0.4859, 0.5364, 0.5467, 0.6103],
          [0.5460, 0.8822, 0.6072, 0.9419],
          [0.4213, 0.3207, 0.4690, 0.3925],
          [0.2543, 0.0000, 0.9988, 0.4808],
          [0.4472, 0.7747, 0.5409, 0.9983],
          [0.2625, 0.6372, 0.7486, 0.9983]]),
  'nb_regions': 36,
  'answer_id': 382,
  'class_id': tensor([382]),
  'answer': 'pitcher',
  'question_type': 'what' }
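A detail worth noticing in this dump: norm_coord looks like coord divided element-wise by the image width and height (here presumably 640x480, since 639.2/640 ≈ 0.9988 and 479.2/480 ≈ 0.9983). A small sketch to verify this against the first box, with the image size as an assumption:

import torch

im_w, im_h = 640.0, 480.0  # assumed original image size
scale = torch.tensor([im_w, im_h, im_w, im_h])
coord0 = torch.tensor([282.9814, 302.8545, 372.2248, 468.0808])  # first row of 'coord'
print(coord0 / scale)  # tensor([0.4422, 0.6309, 0.5816, 0.9752]) -- matches norm_coord[0]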