faster rcnn源码解析

2024-08-24 18:08
文章标签 源码 解析 faster rcnn

本文主要是介绍faster rcnn源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

之前一直是使用faster rcnn对其中的代码并不是很了解,这次刚好复现mask rcnn就仔细阅读了faster rcnn,主要参考代码是pytorch-faster-rcnn ,部分参考和借用了以下博客的图片
[1] CNN目标检测(一):Faster RCNN详解

姊妹篇mask rcnn解析

整体框架

整体架构

  1. 首先图片进行放缩到W*H,然后送入vgg16(去掉了pool5),得到feature map(W/16, H/16)
  2. 然后feature map上每个点都对应原图上的9个anchor,送入rpn层后输出两个: 这9个anchor前背景的概率以及4个坐标的回归
  3. 每个anchor经过回归后对应到原图,然后再对应到feature map经过roi pooling后输出7*7大小的map
  4. 最后对这个7*7的map进行分类和再次回归
    (此处均为大体轮廓,具体细节见后面)

数据层

  1. 主要利用工厂模式适配各种数据集 factory.py中利用lambda表达式(泛函)
  2. 自定义适配自己数据集的类,继承于imdb
  3. 主要针对数据集中生成roidb,对于每个图片保持其中含有的所有的box坐标(0-index)及其类别,然后顺便保存它的面积等参数,最后记录所有图片的index及其根据index获取绝对地址的方法
# factory.py
from datasets.mydataset import mydataset
for dataset in ['xxdataset']:for split in ['train', 'val', 'test']:name = '{}_{}'.format(dataset, split)__sets[name] = (lambda split=split,dataset=dataset: mydataset(split, dataset))

RPN

这里写图片描述

anchors生成

经过feature extraction后,feature map的大小是(W/16, H/16), 记为(w,h),然后每个feature map每个点生成k个anchor,论文中设置了3中ratio, 3种scale 共产生了w*h*9个anchors
anchors

# # array([[ -83.,  -39.,  100.,   56.],
#       [-175.,  -87.,  192.,  104.],
#       [-359., -183.,  376.,  200.],
#       [ -55.,  -55.,   72.,   72.],
#       [-119., -119.,  136.,  136.],
#       [-247., -247.,  264.,  264.],
#       [ -35.,  -79.,   52.,   96.],
#       [ -79., -167.,   96.,  184.],
#       [-167., -343.,  184.,  360.]])
#  先以左上角(0,0)为例生成9个anchor,然后在向右向下移动,生成整个feature map所有点对应的anchor

anchors前背景和坐标预测

正如整体框架上画的那样,feature map后先跟了一个3*3的卷积,然后分别用2个1*1的卷积,预测feature map上每个点对应的9个anchor属于前背景的概率(9*2)和4个回归的坐标(9*4)

# rpn
self.rpn_net = nn.Conv2d(self._net_conv_channels, cfg.RPN_CHANNELS, [3, 3], padding=1)
self.rpn_cls_score_net = nn.Conv2d(cfg.RPN_CHANNELS, self._num_anchors * 2, [1, 1])
self.rpn_bbox_pred_net = nn.Conv2d(cfg.RPN_CHANNELS, self._num_anchors * 4, [1, 1])rpn = F.relu(self.rpn_net(net_conv))
rpn_cls_score = self.rpn_cls_score_net(rpn) # batch * (num_anchors * 2) * h * w
rpn_bbox_pred = self.rpn_bbox_pred_net(rpn) # batch * (num_anchors * 4) * h * w

anchor target

对上一步产生的anchor分配target label,1前景or0背景or-1忽略,以便训练rpn(只有分配了label的才能计算loss,即参与训练)
无NMS
1. 对于每个gt box,找到与他iou最大的anchor然后设为正样本
2. 对于每个anchor只要它与任意一个gt box iou>0.7即设为正样本
3. 对于每个anchor它与任意一个gt box iou都<0.3即设为负样本
4. 不是正也不是负的anchor被忽略

注意
正样本的数量由num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)控制,默认是256*0.5=128,即最多有128个正样本参与rpn的训练. 假如正样本有1234个,则随机抽1234-128个正样本将其label设置为-1,即忽略掉,当然正样本也有可能不足128个,那就都保留下来.
负样本的数量由num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1),同理如果超额也为多余的忽略.
TRAIN.RPN_FG_FRACTION控制参与rpn训练的正样本的数量

注意在RPN阶段需要的配置参数都有RPN前缀,与后面的fast rcnn的参数区别开

# Max number of foreground examples
# __C.TRAIN.RPN_FG_FRACTION = 0.5
# Total number of examples
#__C.TRAIN.RPN_BATCHSIZE = 256# subsample positive labels if we have too many
num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)
fg_inds = np.where(labels == 1)[0]
if len(fg_inds) > num_fg:disable_inds = npr.choice(fg_inds, size=(len(fg_inds) - num_fg), replace=False)labels[disable_inds] = -1# subsample negative labels if we have too many
num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1)
bg_inds = np.where(labels == 0)[0]
if len(bg_inds) > num_bg:disable_inds = npr.choice(bg_inds, size=(len(bg_inds) - num_bg), replace=False)labels[disable_inds] = -1

Fast RCNN

proposal

对RPN产生的anchor进行处理,有NMS
1. 首先利用4个坐标回归值对默认的w*h*9个anchor进行坐标变换生成proposal
2. 然后利用前景概率对这些proposal进行降序排列,然后留下RPN_PRE_NMS_TOP_N个proposal 训练是留下12000,测试是留下6000
3. 对剩下的proposal进行NMS处理,阈值是0.7
4. 对于剩下的proposal,只留下RPN_POST_NMS_TOP_N,训练是2000,测试是300
最终剩下的proposal即为rois了

proposal target

对留下的proposal(train:2000, test没有这个阶段,因为测试不知道gt无法分配)分配target label,属于具体哪一个类别,以便训练后面的分类器, 下面以train阶段的某个图片为例即该张图片有2000个proposal,gt中含有15个类别的box(不含背景) (全库有20个类别)

# Minibatch size (number of regions of interest [ROIs])
# __C.TRAIN.BATCH_SIZE = 128
# Fraction of minibatch that is labeled foreground (i.e. class > 0)
# __C.TRAIN.FG_FRACTION = 0.25 控制fast rcnn中rois的正负样本比例为1:3
num_images = 1
rois_per_image = cfg.TRAIN.BATCH_SIZE / num_images # 默认为128
fg_rois_per_image = int(round(cfg.TRAIN.FG_FRACTION * rois_per_image))  # 0.25*128
  1. 计算每个roi(proposal)与15个gt box做iou,得到overlaps(2000, 15) ,然后选择最大的iou作为这个roi的gt label(坑点: gt box的顺序不一定和label对应,一定要取gt box的第4个维度作为label,因为可能包含15个gt box,但是全库是有20中label的)
  2. 然后记roi与其target label的ovlap>TRAIN.FG_THRESH(0.5)的为fg,0.1
if fg_inds.numel() > 0 and bg_inds.numel() > 0:fg_rois_per_image = min(fg_rois_per_image, fg_inds.numel())fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)).long().cuda()]
#  ......
#  主要解读npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)
#  在np.arange(0, fg_inds.numel())随机取int(fg_rois_per_image)个数,replace=False不允许重复

roi pooling

上一步得到了很多大小不一的roi,对应到feature map上也是大小不一的,但是fc是需要fixed size的,于是根据SPPNet论文笔记和caffe实现说明,出来了roi pooling(spp poolingfroze 前面的卷积只更新后面的fc,why见fast rcnn的2.3段解释的)
我主要参考了这篇博客Region of interest pooling explained,但是我感觉它的示意图是有问题的,应该有overlap的

1. 我们首先根据feature map和原图的比例,把roi在原图上的坐标映射到feature map上, 然后扣出roi对应部分的feature(蓝色框为实际位置,浮点坐标(1.2,0.8)(7.2,9.7),四舍五入量化到红色框(1,1)(7,10))

int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale);  // spatial_scale 1/16
int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale);
int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale);
int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale);

这里写图片描述

2. 对红色红色框进行roipooling

float bin_size_h = (float)(roi_height) / (float)(pooled_height);  // 9/7
float bin_size_w = (float)(roi_width) / (float)(pooled_width);  // 7/7=1
for (ph = 0; ph < pooled_height; ++ph){for (pw = 0; pw < pooled_width; ++pw){int hstart = (floor((float)(ph) * bin_size_h));  int wstart = (floor((float)(pw) * bin_size_w));int hend = (ceil((float)(ph + 1) * bin_size_h));int wend = (ceil((float)(pw + 1) * bin_size_w));hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height);hend = fminf(fmaxf(hend + roi_start_h, 0), data_height);wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width);wend = fminf(fmaxf(wend + roi_start_w, 0), data_width);
// ......
// 经过计算后w步长为1,窗口为1,没有overlap,h窗口步长不定都有overlap,注意在ph=3时窗口为3了
// 注意边界 pw=pooled_width-1时 wend=(ceil((float)(pw + 1) * bin_size_w))
//  =(ceil((float)pooled_width * (float)(roi_width) / (float)
//  =(pooled_width)))=ceil(roi_width)=roi_width
//  刚好把所有roi对应的feature map覆盖完,hend同理
//  roi_height roi_width小于pooled_height pooled_width时overlap就多一点呗

这里写图片描述

3. 对每个划分的pool bin进行max或者average pooling最后得到7*7的feature map

分类和回归

roi pooling后就得到fixed size的feature map(7*7),然后送入cls_score_net得到分类,送入bbox_pred_net粗暴的坐标回归和rpn时一样

self.cls_score_net = nn.Linear(self._fc7_channels, self._num_classes)
self.bbox_pred_net = nn.Linear(self._fc7_channels, self._num_classes * 4)

Loss采用smooth L1 Loss(和fast rcnn一致,rcnn采用的是L2 Loss)。
这里写图片描述

prior_centers = center_size(prior_boxes) #(cx, cy, w, h)
gt_centers = center_size(gt_boxes) #(cx, cy, w, h)
# tx=(gx-px)/pw  gx是gt的中心坐标x,px是proposal的中心坐标x,pw是预测的宽。ty同理
center_targets = (gt_centers[:, :2] - prior_centers[:, :2]) / prior_centers[:, 2:]
# 参照tw,th的公式
# 加了log, 降低w,h产生的loss的数量级, 让它在loss里占的比重小些, 不至于因为w,h的loss太大而让x,y产生的loss无用
# 因为若是x,y没预测准确, w,h再准确也没有用. 
size_targets = torch.log(gt_centers[:, 2:]) - torch.log(prior_centers[:, 2:])
all_targets = torch.cat((center_targets, size_targets), 1)
loss = F.smooth_l1_loss(deltas, all_targets, size_average=False)/(eps + prior_centers.size(0))

smoothed L1 Loss is a robust L1 loss that is less sensitive to outliers than the L2 loss used in R-CNN and SPPnet.

上述是Fast RCNN解释为什么采用smoothed L1, 因为它对噪音点不那么敏感,即对离目标太远的点不敏感。因为L2loss求导后 0.5*(t-v)^2 求导-> (t-v) 会有一个(t-v) 的系数在,如果v离t太远梯度很容易爆炸(需要精致地调节学习率),而smoothed L1中当|t-v|>1, |t-v|-0.5 求导-> 系数是±1, 这样就避免了梯度爆炸, 也就是它更加鲁棒。(t是target,v是需要预测出来的中心xy和尺寸wh)
这里写图片描述

测试

继续假设全部类别数是20种
1. 图片送入网络后前传,没有给anchor proposal指定gt的部分(忽略_anchor_target_layer _proposal_target_layer)
2. 经过proposal得到300个roi,经过cls_score_net bbox_pred_net得到每个roi在20个类别的置信度和4个坐标回归值(可在测试时把这个回归值用上,也可以不用)
3. 测试时300个roi类别未知,所以可以对应20个类别,即有300*20个box,300*20个置信度
3. 对每一类,取300个roi>thresh(默认为0.),然后进行nms获得留下的box
4. 然后对20类留下的所有box,按概率排序,留下设定的max_per_image个box
有个不解就是为什么对于每个roi,不是选择其置信度最大的类别,而可以对应到20种类别,可能是map算法,同等置信度下,多一些box得分会高一些

for j in range(1, imdb.num_classes):inds = np.where(scores[:, j] > thresh)[0]cls_scores = scores[inds, j]cls_boxes = boxes[inds, j*4:(j+1)*4]cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \.astype(np.float32, copy=False)keep = nms(torch.from_numpy(cls_dets), cfg.TEST.NMS).numpy() if cls_dets.size > 0 else []cls_dets = cls_dets[keep, :]all_boxes[j][i] = cls_dets

延伸

验证一下nms在训练时是不是必须的
参考An Implementation of Faster RCNN with Study for Region Sampling
这里写图片描述

• First, take the top K regions according to RPN score.
• Then, non-maximal suppression (NMS) with overlapping ratio of 0.7 is applied to perform de-duplication.
• Third, top k regions are selected as RoIs.
Intuitively, it is more likely for large regions to overlap than small regions, so large regions have a higher chance to be suppressed对这句话保留意见,nms算的是iou,没有偏向抑制大的region吧
ALL是top12000 proposal都送入后面的网络,不进行nms PRE是利用第一行已经训练好的faster rcnn直接得到最终的正负样本比例 POW: 比例和scale成反比,详细见文章。TOP是test是选择top 5000不进行nms(faster rcnn本身是选择top 6000然后nms,最后再取top300)

In fact, we find this advantage of TOP over NMS consistently exists when K is sufficiently large.

这篇关于faster rcnn源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103212

相关文章

使用Python实现批量访问URL并解析XML响应功能

《使用Python实现批量访问URL并解析XML响应功能》在现代Web开发和数据抓取中,批量访问URL并解析响应内容是一个常见的需求,本文将详细介绍如何使用Python实现批量访问URL并解析XML响... 目录引言1. 背景与需求2. 工具方法实现2.1 单URL访问与解析代码实现代码说明2.2 示例调用

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

SpringCloud配置动态更新原理解析

《SpringCloud配置动态更新原理解析》在微服务架构的浩瀚星海中,服务配置的动态更新如同魔法一般,能够让应用在不重启的情况下,实时响应配置的变更,SpringCloud作为微服务架构中的佼佼者,... 目录一、SpringBoot、Cloud配置的读取二、SpringCloud配置动态刷新三、更新@R

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

在C#中合并和解析相对路径方式

《在C#中合并和解析相对路径方式》Path类提供了几个用于操作文件路径的静态方法,其中包括Combine方法和GetFullPath方法,Combine方法将两个路径合并在一起,但不会解析包含相对元素... 目录C#合并和解析相对路径System.IO.Path类幸运的是总结C#合并和解析相对路径对于 C

Java解析JSON的六种方案

《Java解析JSON的六种方案》这篇文章介绍了6种JSON解析方案,包括Jackson、Gson、FastJSON、JsonPath、、手动解析,分别阐述了它们的功能特点、代码示例、高级功能、优缺点... 目录前言1. 使用 Jackson:业界标配功能特点代码示例高级功能优缺点2. 使用 Gson:轻量

Java如何接收并解析HL7协议数据

《Java如何接收并解析HL7协议数据》文章主要介绍了HL7协议及其在医疗行业中的应用,详细描述了如何配置环境、接收和解析数据,以及与前端进行交互的实现方法,文章还分享了使用7Edit工具进行调试的经... 目录一、前言二、正文1、环境配置2、数据接收:HL7Monitor3、数据解析:HL7Busines

python解析HTML并提取span标签中的文本

《python解析HTML并提取span标签中的文本》在网页开发和数据抓取过程中,我们经常需要从HTML页面中提取信息,尤其是span元素中的文本,span标签是一个行内元素,通常用于包装一小段文本或... 目录一、安装相关依赖二、html 页面结构三、使用 BeautifulSoup javascript

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库