Training tf-faster-rcnn on the KAIST dataset

2023-11-23 01:50

This post describes how to train tf-faster-rcnn on the KAIST multispectral pedestrian dataset, and walks through the data-loading bugs hit along the way.

The first step is to adapt pascal_voc.py and write your own kaist_rgb.py.
The PASCAL VOC data format is awkward to work with: its annotations are XML, whereas custom datasets usually ship plain txt files, so the loader does not need to be nearly as complex as the pascal_voc reader. Below is a KAIST dataset interface that an earlier author wrote on top of pascal_voc.py:

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick 
# --------------------------------------------------------
# import datasets.caltech
import os
from datasets.imdb import imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import pickle
import subprocess
from model.config import cfg


class kaist_rgb(imdb):
    def __init__(self, image_set):
        imdb.__init__(self, 'kaist_' + image_set)  # image_set: train04 or test
        self._image_set = image_set
        self._devkit_path = self._get_default_path()
        self._data_path = self._get_default_path()
        self._classes = ('__background__', 'pedestrian')
        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
        # self._class_to_ind = {'__background__': 0, 'pedestrian': 1}
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None,
                       'min_size': 2}
        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """Return the absolute path to image i in the image sequence."""
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """Construct an image path from the image's "index" identifier."""
        # image_path = os.path.join(self._data_path, self._image_set, 'images', index + self._image_ext)
        image_path = os.path.join(self._data_path, self._image_set, 'images',
                                  index[:-6] + 'visible/' + index[-6:] + self._image_ext)
        assert os.path.exists(image_path), 'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """Load the indexes listed in this dataset's image set file."""
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, self._image_set,
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _get_default_path(self):
        """Return the default path where kaist dataset is expected to be installed."""
        return os.path.join(cfg.DATA_DIR, 'kaist')

    def gt_roidb(self):
        """Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                try:
                    roidb = pickle.load(fid)
                except:
                    roidb = pickle.load(fid, encoding='bytes')
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            # print(roidb)
            # for dic in roidb:
            #     print(dic['gt_overlaps'])
            return roidb

        gt_roidb = [self._load_revised_annotation(index)
                    for index in self.image_index]
        # print(gt_roidb)
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

    def selective_search_roidb(self):
        """Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        if self._image_set != 'test-all':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    def _load_selective_search_roidb(self, gt_roidb):
        filename = os.path.abspath(os.path.join(self.cache_path, '..',
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()
        box_list = []
        for i in range(raw_data.shape[0]):
            box_list.append(raw_data[i][:, :] - 1)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def selective_search_IJCV_roidb(self):
        """Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  '{:s}_selective_search_IJCV_top_{:d}_roidb.pkl'.
                                  format(self.name, self.config['top_k']))
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = self.gt_roidb()
        ss_roidb = self._load_selective_search_IJCV_roidb(gt_roidb)
        roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    def rpn_roidb(self):
        if self._image_set != 'test-all':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)
        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print('loading {}'.format(filename))
        assert os.path.exists(filename), 'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = pickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_IJCV_roidb(self, gt_roidb):
        # Leftover from pascal_voc.py: self._year is never set for kaist.
        IJCV_path = os.path.abspath(os.path.join(self.cache_path, '..',
                                                 'selective_search_IJCV_data',
                                                 'voc_' + self._year))
        assert os.path.exists(IJCV_path), \
            'Selective search IJCV data not found at: {}'.format(IJCV_path)
        top_k = self.config['top_k']
        box_list = []
        for i in range(self.num_images):
            filename = os.path.join(IJCV_path, self.image_index[i] + '.mat')
            raw_data = sio.loadmat(filename)
            box_list.append((raw_data['boxes'][:top_k, :] - 1).astype(np.uint16))
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_revised_annotation(self, index):
        """Load image and bounding boxes info from text file
        in the kaist dataset format.
        """
        filename = os.path.join(self._data_path, self._image_set, 'annotations', index + '.txt')
        # print('Loading: {}'.format(filename))
        with open(filename) as f:
            lines = f.readlines()[1:]
        num_objs = len(lines)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        ix = 0
        for obj in lines:
            # Make pixel indexes 0-based
            info = obj.split()
            # jam
            if self._image_set.find("train") != -1:
                if info[0] == "person":
                    x1 = float(info[1])
                    y1 = float(info[2])
                    x2 = x1 + float(info[3])
                    y2 = y1 + float(info[4])
                    assert(x2 >= x1)
                    assert(y2 >= y1)
                    cls = self._class_to_ind['pedestrian']
                    boxes[ix, :] = [max(x1 - 1, 0), max(y1 - 1, 0), min(x2 - 1, 639), min(y2 - 1, 479)]
                    gt_classes[ix] = cls
                    overlaps[ix, cls] = 1.0
                    seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
            ix = ix + 1

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _write_voc_results_file(self, all_boxes):
        use_salt = self.config['use_salt']
        comp_id = 'comp4'
        if use_salt:
            comp_id += '-{}'.format(os.getpid())

        # VOCdevkit/results/VOC2007/Main/comp4-44503_det_test_aeroplane.txt
        path = os.path.join(self._devkit_path, 'results', 'VOC', 'Main', comp_id + '_')
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = path + 'det_' + self._image_set + '_' + cls + '.txt'
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if dets == []:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))
        return comp_id

    def _do_matlab_eval(self, comp_id, output_dir='output'):
        rm_results = self.config['cleanup']
        path = os.path.join(os.path.dirname(__file__),
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\',{:d}); quit;"' \
            .format(self._devkit_path, comp_id,
                    self._image_set, output_dir, int(rm_results))
        print('Running:\n{}'.format(cmd))
        status = subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir):
        comp_id = self._write_voc_results_file(all_boxes)
        self._do_matlab_eval(comp_id, output_dir)

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True


if __name__ == '__main__':
    d = kaist_rgb('train20')
    res = d.roidb
    from IPython import embed
    embed()
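For the train/test scripts to find the new imdb by name, it also needs to be registered in lib/datasets/factory.py. The exact split names depend on your setup; here is a minimal sketch following the registration pattern used for pascal_voc (the split list is illustrative):

# lib/datasets/factory.py (excerpt) -- a registration sketch; adjust split
# names to the image sets you actually use.
from datasets.kaist_rgb import kaist_rgb

__sets = {}  # already defined once at the top of the real factory.py

# Register one imdb per image set so that get_imdb('kaist_train04') etc. work.
for split in ['train04', 'train20', 'test-all']:
    name = 'kaist_{}'.format(split)
    # The default argument pins 'split' per lambda (lambdas are late-binding).
    __sets[name] = (lambda split=split: kaist_rgb(split))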
With this interface in place, training initially crashed with:

2018-09-06 13:11:49.699197: W tensorflow/core/framework/op_kernel.cc:1263] Invalid argument: ValueError: attempt to get argmax of an empty sequence
Traceback (most recent call last):
  File "/home/ramsey/.local/lib/python3.5/site-packages/tensorflow/python/ops/script_ops.py", line 206, in __call__
    ret = func(*args)
  File "/home/ramsey/tf-faster-rcnn/tools/../lib/layer_utils/anchor_target_layer.py", line 57, in anchor_target_layer
    argmax_overlaps = overlaps.argmax(axis=1)
ValueError: attempt to get argmax of an empty sequence

For reference, here is how the data files are organized.
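The original screenshots of the directory tree are gone; the layout below is reconstructed from the path handling in image_path_from_index, _load_image_set_index, and _load_revised_annotation, with illustrative set/video/frame names:

data/kaist/
    train04/
        train04.txt                      # image index, one id per line, e.g. set00/V000/I00000
        images/
            set00/V000/visible/I00000.jpg
        annotations/
            set00/V000/I00000.txt
    test-all/
        test-all.txt
        images/...
        annotations/...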
By printing parts of the roidb and the computed overlaps, I found that the gt_boxes of the validation roidb were empty.
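A quick way to run this kind of check yourself (my own sketch, not code from the post; it assumes roidb is the list of dicts produced by gt_roidb above):

def report_suspicious_entries(roidb):
    """Flag roidb entries with no gt boxes, or with all-zero box rows."""
    for i, entry in enumerate(roidb):
        boxes = entry['boxes']
        if boxes.size == 0:
            print(i, 'has no gt boxes at all')
        elif not boxes.any(axis=1).all():
            print(i, 'has all-zero gt box rows:')
            print(boxes)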

At first I could not find the cause, so I commented out the "check validation data" part of the train_model method in lib/model/train_val.py, just to get training running.
With that commented out, training did run. But after a certain number of iterations the same error came back:

2018-09-06 13:11:49.699197: W tensorflow/core/framework/op_kernel.cc:1263] Invalid argument: ValueError: attempt to get argmax of an empty sequence
Traceback (most recent call last):
  File "/home/ramsey/.local/lib/python3.5/site-packages/tensorflow/python/ops/script_ops.py", line 206, in __call__
    ret = func(*args)
  File "/home/ramsey/tf-faster-rcnn/tools/../lib/layer_utils/anchor_target_layer.py", line 57, in anchor_target_layer
    argmax_overlaps = overlaps.argmax(axis=1)
ValueError: attempt to get argmax of an empty sequence
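The error itself is easy to reproduce: in anchor_target_layer, overlaps is the anchors-vs-ground-truth IoU matrix, so when an image contributes zero ground-truth boxes the matrix has zero columns and argmax over axis 1 fails. A minimal illustration:

import numpy as np

num_anchors = 10
gt_boxes = np.zeros((0, 4))                            # an image with zero usable gt boxes
overlaps = np.zeros((num_anchors, gt_boxes.shape[0]))  # IoU matrix of shape (10, 0)
overlaps.argmax(axis=1)  # ValueError: attempt to get argmax of an empty sequence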

So I printed the roidb actually used for training (the roidb left after filtering out empty gt boxes) and found entries whose four gt_box coordinates were all zero. From this I guessed that the data itself was the problem.
Going back through the log output, I located the corresponding data; for example, the annotation for the frame I02759.jpg mentioned above contained a label other than person.
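For context, a KAIST annotation file in the bbGt text format looks roughly like the sketch below (the values are invented for illustration); the loader's readlines()[1:] skips the header line, and info[0] and info[1:5] are the label and the x y w h box:

% bbGt version=3
people 471 170 39 54 0 0 0 0 0 0 0
person 304 186 21 58 0 0 0 0 0 0 0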
The label is what makes this annotation odd: every annotation that produced a correct gt_box is labeled person.
So I went back to the original KAIST dataset paper (Multispectral Pedestrian Detection: Benchmark Dataset and Baseline), which explains the labeling policy:

"Obviously an individual pedestrian was labelled as a person. Not distinguishable individuals were labeled as people. People riding a two-wheeled vehicle were labeled as cyclist. In a highly cluttered scene, even human annotators sometimes cannot clearly determine whether a human shaped object is a pedestrian or not. This object is labeled as person? and it is ignored in the validation."

So KAIST annotations contain not just person but also cyclist, person?, and people.

In kaist_rgb.py, however, _load_revised_annotation only reads objects whose label is person, while the numpy arrays are still sized by the total number of annotation lines: boxes = np.zeros((num_objs, 4), dtype=np.uint16). Every non-person line therefore leaves behind a gt_box row whose four coordinates are all zero. This is also why the error survived even though a filter_roidb function exists: tf-faster-rcnn's filter_roidb(roidb) only removes roidb entries with no gt boxes at all; it cannot remove entries whose gt_box coordinates are all zero.
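Concretely, tf-faster-rcnn's minibatch loading only takes boxes whose gt_classes entry is non-zero as ground truth, so a frame whose annotation lines are all people/person?/cyclist ends up with zero usable gt boxes at training time even though its boxes array is non-empty. One possible extra guard, a sketch of my own rather than code from the repo, would drop such entries before training:

def drop_degenerate_entries(roidb):
    """Keep only roidb entries with at least one real, non-zero gt box."""
    def has_valid_gt(entry):
        # a usable box has a foreground class and at least one non-zero coordinate
        keep = (entry['gt_classes'] != 0) & entry['boxes'].any(axis=1)
        return bool(keep.any())
    filtered = [entry for entry in roidb if has_valid_gt(entry)]
    print('roidb filtered: {} -> {} entries'.format(len(roidb), len(filtered)))
    return filtered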

    def _load_revised_annotation(self, index):
        """Load image and bounding boxes info from text file
        in the kaist dataset format.
        """
        filename = os.path.join(self._data_path, self._image_set, 'annotations', index + '.txt')
        # print('Loading: {}'.format(filename))
        with open(filename) as f:
            lines = f.readlines()[1:]
        num_objs = len(lines)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        ix = 0
        for obj in lines:
            # Make pixel indexes 0-based
            info = obj.split()
            # jam
            if self._image_set.find("train") != -1:
                if info[0] == "person":
                    x1 = float(info[1])
                    y1 = float(info[2])
                    x2 = x1 + float(info[3])
                    y2 = y1 + float(info[4])
                    assert(x2 >= x1)
                    assert(y2 >= y1)
                    cls = self._class_to_ind['pedestrian']
                    boxes[ix, :] = [max(x1 - 1, 0), max(y1 - 1, 0), min(x2 - 1, 639), min(y2 - 1, 479)]
                    gt_classes[ix] = cls
                    overlaps[ix, cls] = 1.0
                    seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
            ix = ix + 1

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

So the fix is to modify _load_revised_annotation(self, index). The revised version is below; it simply treats all of person, person?, people, and cyclist as pedestrian:

    def _load_revised_annotation(self, index):
        """Load image and bounding boxes info from text file
        in the kaist dataset format.
        """
        filename = os.path.join(self._data_path, self._image_set, 'annotations', index + '.txt')
        # print('Loading: {}'.format(filename))
        with open(filename) as f:
            lines = f.readlines()[1:]
        num_objs = len(lines)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        ix = 0
        for obj in lines:
            # Make pixel indexes 0-based
            info = obj.split()
            # jam
            # if self._image_set.find("train") != -1:
            #     if info[0] == "person":
            x1 = float(info[1])
            y1 = float(info[2])
            x2 = x1 + float(info[3])
            y2 = y1 + float(info[4])
            assert(x2 >= x1)
            assert(y2 >= y1)
            cls = self._class_to_ind['pedestrian']
            boxes[ix, :] = [max(x1 - 1, 0), max(y1 - 1, 0), min(x2 - 1, 639), min(y2 - 1, 479)]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
            ix = ix + 1

        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}
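After the change, a quick sanity check, sketched here along the lines of the __main__ block above, is to rebuild the roidb and confirm that no all-zero box rows remain:

from datasets.kaist_rgb import kaist_rgb

d = kaist_rgb('train20')   # same image set as the __main__ block above
roidb = d.gt_roidb()       # regenerates the cache file if it was deleted
bad = [i for i, entry in enumerate(roidb)
       if entry['boxes'].size and not entry['boxes'].any(axis=1).all()]
print('entries that still contain all-zero gt box rows:', bad)  # expect []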

Note that you must delete the previously generated files in the cache directory (the *_gt_roidb.pkl files that gt_roidb writes, under data/cache by default); otherwise the stale roidb is loaded from the cache and the change has no effect.
With that, training finally ran successfully.
