mmdetection - anchor-based方法训练流程解析

2024-04-24 11:08

本文主要是介绍mmdetection - anchor-based方法训练流程解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

训练流程图
在这里插入图片描述
最终会创建一个runner,然后调用runner.run时,实际会根据workflow中是train还是val,调用runner.py下的train和val函数。
batch_processor

def batch_processor(model, data, train_mode):# 这里的train_mode实际没用到losses = model(**data)loss, log_vars = parse_losses(losses)outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))return outputs

mmcv/runner/runner.py
train

def train(self, data_loader, **kwargs):self.model.train()self.mode = 'train'self.data_loader = data_loaderself._max_iters = self._max_epochs * len(data_loader)self.call_hook('before_train_epoch')for i, data_batch in enumerate(data_loader):self._inner_iter = iself.call_hook('before_train_iter')outputs = self.batch_processor(self.model, data_batch, train_mode=True, **kwargs)if not isinstance(outputs, dict):raise TypeError('batch_processor() must return a dict')if 'log_vars' in outputs:self.log_buffer.update(outputs['log_vars'],outputs['num_samples'])self.outputs = outputsself.call_hook('after_train_iter')self._iter += 1self.call_hook('after_train_epoch')self._epoch += 1

val

def val(self, data_loader, **kwargs):self.model.eval()self.mode = 'val'self.data_loader = data_loaderself.call_hook('before_val_epoch')for i, data_batch in enumerate(data_loader):self._inner_iter = iself.call_hook('before_val_iter')with torch.no_grad():outputs = self.batch_processor(self.model, data_batch, train_mode=False, **kwargs)if not isinstance(outputs, dict):raise TypeError('batch_processor() must return a dict')if 'log_vars' in outputs:self.log_buffer.update(outputs['log_vars'],outputs['num_samples'])self.outputs = outputsself.call_hook('after_val_iter')self.call_hook('after_val_epoch')

validate目前只在_dist_train中有用到

训练时,实际调用:losses = model(**data),验证时,实际调用hook,运行:

with torch.no_grad():result = runner.model(return_loss=False, rescale=True, **data_gpu)

其中,TwoStageDetector和SingleStageDetector都继承了BaseDetector,在BaseDetector中,forward函数定义如下:

@auto_fp16(apply_to=('img', ))
def forward(self, img, img_meta, return_loss=True, **kwargs):if return_loss:return self.forward_train(img, img_meta, **kwargs)else:return self.forward_test(img, img_meta, **kwargs)

对于forward_test,其代码如下:

def forward_test(self, imgs, img_metas, **kwargs):for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:if not isinstance(var, list):raise TypeError('{} must be a list, but got {}'.format(name, type(var)))num_augs = len(imgs)if num_augs != len(img_metas):raise ValueError('num of augmentations ({}) != num of image meta ({})'.format(len(imgs), len(img_metas)))# TODO: remove the restriction of imgs_per_gpu == 1 when preparedimgs_per_gpu = imgs[0].size(0)assert imgs_per_gpu == 1if num_augs == 1:return self.simple_test(imgs[0], img_metas[0], **kwargs)else:return self.aug_test(imgs, img_metas, **kwargs)

由上可以看出,子类需要写simple_test和aub_test函数。
对于一个检测模型(一阶或者二阶),在其class中,需要重写以下函数:

  • forward_train
  • simple_test
  • aug_test # 非必须

下面以retinanet举个例子,在retinanet的config文件中,model的type是RetinaNet,在mmdet/models/detectors/retinanet.py中,定义了RetinaNet,它的父类是SingleStageDetector,定义在mmdet/models/detectors/single_stage.py中,三个重要函数的代码如下:

def forward_train(self,img,img_metas,gt_bboxes,gt_labels,gt_bboxes_ignore=None):x = self.extract_feat(img)outs = self.bbox_head(x)loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)losses = self.bbox_head.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)return lossesdef simple_test(self, img, img_meta, rescale=False):x = self.extract_feat(img)outs = self.bbox_head(x)bbox_inputs = outs + (img_meta, self.test_cfg, rescale)bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)bbox_results = [bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)for det_bboxes, det_labels in bbox_list]return bbox_results[0]def aug_test(self, imgs, img_metas, rescale=False):raise NotImplementedError

由上可知,计算loss的函数是在head中定义的,RetinaHead定义在mmdet/models/anchor_heads/retina_head.py中,RetinaHead三个关键函数的代码如下:

def _init_layers(self):self.relu = nn.ReLU(inplace=True)self.cls_convs = nn.ModuleList()self.reg_convs = nn.ModuleList()for i in range(self.stacked_convs):chn = self.in_channels if i == 0 else self.feat_channelsself.cls_convs.append(ConvModule(chn,self.feat_channels,3,stride=1,padding=1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg))self.reg_convs.append(ConvModule(chn,self.feat_channels,3,stride=1,padding=1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg))self.retina_cls = nn.Conv2d(self.feat_channels,self.num_anchors * self.cls_out_channels,3,padding=1)self.retina_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 3, padding=1)def init_weights(self):for m in self.cls_convs:normal_init(m.conv, std=0.01)for m in self.reg_convs:normal_init(m.conv, std=0.01)bias_cls = bias_init_with_prob(0.01)normal_init(self.retina_cls, std=0.01, bias=bias_cls)normal_init(self.retina_reg, std=0.01)def forward_single(self, x):cls_feat = xreg_feat = xfor cls_conv in self.cls_convs:cls_feat = cls_conv(cls_feat)for reg_conv in self.reg_convs:reg_feat = reg_conv(reg_feat)cls_score = self.retina_cls(cls_feat)bbox_pred = self.retina_reg(reg_feat)return cls_score, bbox_pred

其中,_init_layers创建head的结构,init_weights对conv的weight和bias做初始化,forward_single是经过head计算得到的分类和检测框预测结果。
forward
在具体的方法对应的head定义forward_single,最后由anchor_head.py中的forward函数进行组装。

from six.moves import map, zip
def multi_apply(func, *args, **kwargs):pfunc = partial(func, **kwargs) if kwargs else func # 将func的kwargs固定,返回该函数# 这里的*args=feats,调用forward_single对feats的元素依次跑前向map_results = map(pfunc, *args) # 得到[(stride1_cls,stride1_bbox,...), (stride2_cls,stride2_bbox, ...]return tuple(map(list, zip(*map_results)))# zip(*map_results) 得到 [(stride1_cls,stride2_cls,stride3_cls,...),(stride1_bbox,stride2_bbox,stride3_bbox,...)]# map(list, zip(*map_results)) 将(stride1_cls,stride2_cls,stride3_cls,...)变为[stride1_cls,stride2_cls,stride3_cls,...]# tuple之后,最后得到([stride1_cls,stride2_cls,stride3_cls,...],[stride1_bbox,stride2_bbox,stride3_bbox,...])def forward(self, feats):# 输入feats是一个list,长度为stride个数,其中元素为nchwreturn multi_apply(self.forward_single, feats)def forward_single(self, x):# 这里的x为feats中的某一个元素cls_feat = xreg_feat = xfor cls_conv in self.cls_convs:cls_feat = cls_conv(cls_feat)for reg_conv in self.reg_convs:reg_feat = reg_conv(reg_feat)cls_score = self.retina_cls(cls_feat)bbox_pred = self.retina_reg(reg_feat)return cls_score, bbox_pred

loss

这篇关于mmdetection - anchor-based方法训练流程解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/931582

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

Security OAuth2 单点登录流程

单点登录(英语:Single sign-on,缩写为 SSO),又译为单一签入,一种对于许多相互关连,但是又是各自独立的软件系统,提供访问控制的属性。当拥有这项属性时,当用户登录时,就可以获取所有系统的访问权限,不用对每个单一系统都逐一登录。这项功能通常是以轻型目录访问协议(LDAP)来实现,在服务器上会将用户信息存储到LDAP数据库中。相同的,单一注销(single sign-off)就是指

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

【北交大信息所AI-Max2】使用方法

BJTU信息所集群AI_MAX2使用方法 使用的前提是预约到相应的算力卡,拥有登录权限的账号密码,一般为导师组共用一个。 有浏览器、ssh工具就可以。 1.新建集群Terminal 浏览器登陆10.126.62.75 (如果是1集群把75改成66) 交互式开发 执行器选Terminal 密码随便设一个(需记住) 工作空间:私有数据、全部文件 加速器选GeForce_RTX_2080_Ti

【VUE】跨域问题的概念,以及解决方法。

目录 1.跨域概念 2.解决方法 2.1 配置网络请求代理 2.2 使用@CrossOrigin 注解 2.3 通过配置文件实现跨域 2.4 添加 CorsWebFilter 来解决跨域问题 1.跨域概念 跨域问题是由于浏览器实施了同源策略,该策略要求请求的域名、协议和端口必须与提供资源的服务相同。如果不相同,则需要服务器显式地允许这种跨域请求。一般在springbo