YOLOv8改进 | 2023 | 给YOLOv8换个RT-DETR的检测头(重塑目标检测前沿技术)

2023-12-03 07:20

本文主要是介绍YOLOv8改进 | 2023 | 给YOLOv8换个RT-DETR的检测头(重塑目标检测前沿技术),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、本文介绍

本文给大家带来是用最新的RT-DETR模型的检测头去替换YOLOv8中的检测头。RT-DETR号称是打败YOLO的检测模型,其作为一种基于Transformer的检测方法,相较于传统的基于卷积的检测方法,提供了更为全面和深入的特征理解,将RT-DETR检测头融入YOLOv8,我们可以结合YOLO的实时检测能力和RT-DETR的深度特征理解能力,打造出一个更加强大的目标检测模型。亲测这一整合不仅提高了模型在复杂场景下的表现,还显著提升了对小目标和遮挡目标的检测能力。此外,模型在多种标准数据集上的表现也有了明显提升,特别是在处理高动态范围和复杂背景的图像时,其表现尤为出色。

适用检测目标:需要注意的是本文的改进并不一定适合所有的数据集可能只有在部分的数据集有效(听人反馈部分的数据集可能掉点)。

推荐指数:⭐⭐⭐

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

训练结果对比图->  

因为资源有限我发的文章都要做对比实验所以本次实验我只用了一百张图片检测的是火灾训练了二百个epoch,该结果只能展示出该机制有效,但是并不能产生决定性结果,因为具体的效果还要看你的数据集和实验环境所影响

可以看出部分的检测有提点但是有的有降点所以具体的效果还要看你自己的数据集来看。 

 

目录

一、本文介绍

二、RT-DETR检测头框架原理

 2.1 RT-DETR的基本原理

三、RT-DETR检测头的代码 

四、手把手教你添加RT-DETR检测头

4.1 RT-DETR的添加教程

4.2 的yaml文件和训练截图

4.2.1 RT-DETR检测头的yaml文件

4.3 训练方法 

4.3.1 RT-DETR的训练过程截图 

六、本文总结


二、RT-DETR检测头框架原理

论文地址:RT-DETR论文地址

代码地址:RT-DETR官方下载地址

 2.1 RT-DETR的基本原理

RT-DETR系统中的检测头变换器解码器(transformer decoder)部分,包括辅助预测头,是该系统的核心组成之一。变换器解码器在RT-DETR中扮演了重要角色,主要负责处理经过混合编码器加工后的特征,并对这些特征进行目标检测。这一部分的设计是基于Transformer架构的,该架构已在自然语言处理领域取得了巨大成功,并在最近几年逐渐被应用于计算机视觉任务中。

在RT-DETR中,变换器解码器利用Transformer的强大能力来捕捉和理解图像中的复杂模式。它通过自注意力机制来分析和加权输入特征,从而能够更准确地定位和识别图像中的不同对象。这种方法使RT-DETR能够在保持高速度的同时,提供比传统方法更精准的检测结果。

辅助预测头则是变换器解码器的一个关键组件,它被用来进一步提升模型的检测性能。这些预测头直接作用于解码器的输出,负责生成最终的检测结果,包括对象的类别、位置和大小等信息。通过这种设计,RT-DETR能够在端到端的框架内完成整个目标检测流程,无需依赖于传统的基于锚点的方法或复杂的后处理步骤,这大大简化了检测流程并提高了效率。

RT-DETR网络详解:详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署

 


三、RT-DETR检测头的代码 

以下的代码时RT-DETR检测头的代码,如果你的YOLO仓库已经是新版本的了(好像是9月份的时候更新的)就已经集成这个代码了无需做任何修改即可使用该检测头,但是该检测头有一些使用注意事项后面第四章我会提到。

class RTDETRDecoder(nn.Module):"""Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxesand class labels for objects in an image. It integrates features from multiple layers and runs through a series ofTransformer decoder layers to output the final predictions."""export = False  # export modedef __init__(self,nc=80,ch=(512, 1024, 2048),hd=256,  # hidden dimnq=300,  # num queriesndp=4,  # num decoder pointsnh=8,  # num headndl=6,  # num decoder layersd_ffn=1024,  # dim of feedforwarddropout=0.,act=nn.ReLU(),eval_idx=-1,# Training argsnd=100,  # num denoisinglabel_noise_ratio=0.5,box_noise_scale=1.0,learnt_init_query=False):"""Initializes the RTDETRDecoder module with the given parameters.Args:nc (int): Number of classes. Default is 80.ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).hd (int): Dimension of hidden layers. Default is 256.nq (int): Number of query points. Default is 300.ndp (int): Number of decoder points. Default is 4.nh (int): Number of heads in multi-head attention. Default is 8.ndl (int): Number of decoder layers. Default is 6.d_ffn (int): Dimension of the feed-forward networks. Default is 1024.dropout (float): Dropout rate. Default is 0.act (nn.Module): Activation function. Default is nn.ReLU.eval_idx (int): Evaluation index. Default is -1.nd (int): Number of denoising. Default is 100.label_noise_ratio (float): Label noise ratio. Default is 0.5.box_noise_scale (float): Box noise scale. Default is 1.0.learnt_init_query (bool): Whether to learn initial query embeddings. Default is False."""super().__init__()self.hidden_dim = hdself.nhead = nhself.nl = len(ch)  # num levelself.nc = ncself.num_queries = nqself.num_decoder_layers = ndl# Backbone feature projectionself.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)# NOTE: simplified version but it's not consistent with .pt weights.# self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)# Transformer moduledecoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)# Denoising partself.denoising_class_embed = nn.Embedding(nc, hd)self.num_denoising = ndself.label_noise_ratio = label_noise_ratioself.box_noise_scale = box_noise_scale# Decoder embeddingself.learnt_init_query = learnt_init_queryif learnt_init_query:self.tgt_embed = nn.Embedding(nq, hd)self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)# Encoder headself.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))self.enc_score_head = nn.Linear(hd, nc)self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)# Decoder headself.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])self._reset_parameters()def forward(self, x, batch=None):"""Runs the forward pass of the module, returning bounding box and classification scores for the input."""from ultralytics.models.utils.ops import get_cdn_group# Input projection and embeddingfeats, shapes = self._get_encoder_input(x)# Prepare denoising trainingdn_embed, dn_bbox, attn_mask, dn_meta = \get_cdn_group(batch,self.nc,self.num_queries,self.denoising_class_embed.weight,self.num_denoising,self.label_noise_ratio,self.box_noise_scale,self.training)embed, refer_bbox, enc_bboxes, enc_scores = \self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)# Decoderdec_bboxes, dec_scores = self.decoder(embed,refer_bbox,feats,shapes,self.dec_bbox_head,self.dec_score_head,self.query_pos_head,attn_mask=attn_mask)x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_metaif self.training:return x# (bs, 300, 4+nc)y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)return y if self.export else (y, x)def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):"""Generates anchor bounding boxes for given shapes with specific grid size and validates them."""anchors = []for i, (h, w) in enumerate(shapes):sy = torch.arange(end=h, dtype=dtype, device=device)sx = torch.arange(end=w, dtype=dtype, device=device)grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)valid_WH = torch.tensor([h, w], dtype=dtype, device=device)grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1anchors = torch.log(anchors / (1 - anchors))anchors = anchors.masked_fill(~valid_mask, float('inf'))return anchors, valid_maskdef _get_encoder_input(self, x):"""Processes and returns encoder inputs by getting projection features from input and concatenating them."""# Get projection featuresx = [self.input_proj[i](feat) for i, feat in enumerate(x)]# Get encoder inputsfeats = []shapes = []for feat in x:h, w = feat.shape[2:]# [b, c, h, w] -> [b, h*w, c]feats.append(feat.flatten(2).permute(0, 2, 1))# [nl, 2]shapes.append([h, w])# [b, h*w, c]feats = torch.cat(feats, 1)return feats, shapesdef _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):"""Generates and prepares the input required for the decoder from the provided features and shapes."""bs = len(feats)# Prepare input for decoderanchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)features = self.enc_output(valid_mask * feats)  # bs, h*w, 256enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)# Query selection# (bs, num_queries)topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)# (bs, num_queries)batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)# (bs, num_queries, 256)top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)# (bs, num_queries, 4)top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)# Dynamic anchors + static contentrefer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchorsenc_bboxes = refer_bbox.sigmoid()if dn_bbox is not None:refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_featuresif self.training:refer_bbox = refer_bbox.detach()if not self.learnt_init_query:embeddings = embeddings.detach()if dn_embed is not None:embeddings = torch.cat([dn_embed, embeddings], 1)return embeddings, refer_bbox, enc_bboxes, enc_scores# TODOdef _reset_parameters(self):"""Initializes or resets the parameters of the model's various components with predefined weights and biases."""# Class and bbox head initbias_cls = bias_init_with_prob(0.01) / 80 * self.nc# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.# linear_init_(self.enc_score_head)constant_(self.enc_score_head.bias, bias_cls)constant_(self.enc_bbox_head.layers[-1].weight, 0.)constant_(self.enc_bbox_head.layers[-1].bias, 0.)for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):# linear_init_(cls_)constant_(cls_.bias, bias_cls)constant_(reg_.layers[-1].weight, 0.)constant_(reg_.layers[-1].bias, 0.)linear_init_(self.enc_output[0])xavier_uniform_(self.enc_output[0].weight)if self.learnt_init_query:xavier_uniform_(self.tgt_embed.weight)xavier_uniform_(self.query_pos_head.layers[0].weight)xavier_uniform_(self.query_pos_head.layers[1].weight)for layer in self.input_proj:xavier_uniform_(layer[0].weight)

 


四、手把手教你添加RT-DETR检测头

4.1 RT-DETR的添加教程

这个检测头如果你已经是最新版本的YOLO了那么已经集成在你的代码里了,而且更换了该检测头需要用RT-DETR的训练方法(用YOLO版本的训练方法会报错)。同时如果你使用了该检测头需要增大训练epochs,比如你原先150能够收敛模型那么改完之后可能需要200-250才能够收敛。

具体的训练方法需要修改完下一小节的配置文件后看4.3章节。


4.2 的yaml文件和训练截图

4.2.1 RT-DETR检测头的yaml文件

需要把YOLO的检测头更换为RT-DETR的检测头修改如下->

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, RTDETRDecoder, [nc]]  # Detect(P3, P4, P5)

 


4.3 训练方法 

前面提到了修改完之后进行训练和YOLO变得不在一样需要按照RT-DETR的训练方式(否则会报错)我使用的新建文件进行训练,可以新建一个run.py文件然后把以下代码复制把你的数据集更新进去即可进行训练。

import warnings
warnings.filterwarnings('ignore')
from ultralytics import RTDETRif __name__ == '__main__':model = RTDETR('你替换了RT-DETR检测头的yaml文件地址')model.train(data='替换你数据集的yaml文件地址',imgsz=640,epochs=200,batch=16,workers=0,device=0,optimizer='SGD', # 这里可以使用两个优化器SGD 和AdamW其它的可能会导致模型无法收敛)

 


4.3.1 RT-DETR的训练过程截图 

下面是添加了RT-DETR的训练截图。

(最近有人说我改的代码是没有发全的,我不知道这群人是怎么说出这种话的,希望大家如果用我的代码成功的可以在评论区支持一下,我也好发更多的改进毕竟免费给大家看。同时有问题皆可在评论区留言我看到都会回复) 

​​​

 


六、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

3d51a0611af1442f833362eaf18fbae2.gif

 

 

 

 

这篇关于YOLOv8改进 | 2023 | 给YOLOv8换个RT-DETR的检测头(重塑目标检测前沿技术)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/448598

相关文章

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

烟火目标检测数据集 7800张 烟火检测 带标注 voc yolo

一个包含7800张带标注图像的数据集,专门用于烟火目标检测,是一个非常有价值的资源,尤其对于那些致力于公共安全、事件管理和烟花表演监控等领域的人士而言。下面是对此数据集的一个详细介绍: 数据集名称:烟火目标检测数据集 数据集规模: 图片数量:7800张类别:主要包含烟火类目标,可能还包括其他相关类别,如烟火发射装置、背景等。格式:图像文件通常为JPEG或PNG格式;标注文件可能为X

基于 YOLOv5 的积水检测系统:打造高效智能的智慧城市应用

在城市发展中,积水问题日益严重,特别是在大雨过后,积水往往会影响交通甚至威胁人们的安全。通过现代计算机视觉技术,我们能够智能化地检测和识别积水区域,减少潜在危险。本文将介绍如何使用 YOLOv5 和 PyQt5 搭建一个积水检测系统,结合深度学习和直观的图形界面,为用户提供高效的解决方案。 源码地址: PyQt5+YoloV5 实现积水检测系统 预览: 项目背景

JavaFX应用更新检测功能(在线自动更新方案)

JavaFX开发的桌面应用属于C端,一般来说需要版本检测和自动更新功能,这里记录一下一种版本检测和自动更新的方法。 1. 整体方案 JavaFX.应用版本检测、自动更新主要涉及一下步骤: 读取本地应用版本拉取远程版本并比较两个版本如果需要升级,那么拉取更新历史弹出升级控制窗口用户选择升级时,拉取升级包解压,重启应用用户选择忽略时,本地版本标志为忽略版本用户选择取消时,隐藏升级控制窗口 2.

CSP 2023 提高级第一轮 CSP-S 2023初试题 完善程序第二题解析 未完

一、题目阅读 (最大值之和)给定整数序列 a0,⋯,an−1,求该序列所有非空连续子序列的最大值之和。上述参数满足 1≤n≤105 和 1≤ai≤108。 一个序列的非空连续子序列可以用两个下标 ll 和 rr(其中0≤l≤r<n0≤l≤r<n)表示,对应的序列为 al,al+1,⋯,ar​。两个非空连续子序列不同,当且仅当下标不同。 例如,当原序列为 [1,2,1,2] 时,要计算子序列 [

一种改进的red5集群方案的应用、基于Red5服务器集群负载均衡调度算法研究

转自: 一种改进的red5集群方案的应用: http://wenku.baidu.com/link?url=jYQ1wNwHVBqJ-5XCYq0PRligp6Y5q6BYXyISUsF56My8DP8dc9CZ4pZvpPz1abxJn8fojMrL0IyfmMHStpvkotqC1RWlRMGnzVL1X4IPOa_  基于Red5服务器集群负载均衡调度算法研究 http://ww

HNU-2023电路与电子学-实验3

写在前面: 一、实验目的 1.了解简易模型机的内部结构和工作原理。 2.分析模型机的功能,设计 8 重 3-1 多路复用器。 3.分析模型机的功能,设计 8 重 2-1 多路复用器。 4.分析模型机的工作原理,设计模型机控制信号产生逻辑。 二、实验内容 1.用 VERILOG 语言设计模型机的 8 重 3-1 多路复用器; 2.用 VERILOG 语言设计模型机的 8 重 2-1 多

[数据集][目标检测]血细胞检测数据集VOC+YOLO格式2757张4类别

数据集格式:Pascal VOC格式+YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2757 标注数量(xml文件个数):2757 标注数量(txt文件个数):2757 标注类别数:4 标注类别名称:["Platelets","RBC","WBC","sickle cell"] 每个类别标注的框数:

Temu官方宣导务必将所有的点位材料进行检测-RSL资质检测

关于饰品类产品合规问题宣导: 产品法规RSL要求 RSL测试是根据REACH法规及附录17的要求进行测试。REACH法规是欧洲一项重要的法规,其中包含许多对化学物质进行限制的规定和高度关注物质。 为了确保珠宝首饰的安全性,欧盟REACH法规规定,珠宝首饰上架各大电商平台前必须进行RSLReport(欧盟禁限用化学物质检测报告)资质认证,以确保产品不含对人体有害的化学物质。 RSL-铅,

YOLOv8/v10+DeepSORT多目标车辆跟踪(车辆检测/跟踪/车辆计数/测速/禁停区域/绘制进出线/绘制禁停区域/车道车辆统计)

01:YOLOv8 + DeepSort 车辆跟踪 该项目利用YOLOv8作为目标检测模型,DeepSort用于多目标跟踪。YOLOv8负责从视频帧中检测出车辆的位置,而DeepSort则负责关联这些检测结果,从而实现车辆的持续跟踪。这种组合使得系统能够在视频流中准确地识别并跟随特定车辆。 02:YOLOv8 + DeepSort 车辆跟踪 + 任意绘制进出线 在此基础上增加了用户