RTDETR论文快速理解和代码快速实现(训练与预测)

2023-12-21 03:28

本文主要是介绍RTDETR论文快速理解和代码快速实现(训练与预测),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 前言
  • 一、摘要
  • 二、论文目的
  • 三、论文贡献
  • 四、模型结构
    • 1、模型整体结构
    • 2、backbone结构
    • 3、neck结构
    • 4、混合编码器(neck)
  • 五、RTDERT模型训练(data-->train)
    • 1、环境安装
    • 2、训练
      • 1、数据准备
      • 2、数据yaml文件
      • 3、训练代码
      • 4、训练运行结果
    • 3、推理
      • 1、推理代码
      • 2、推理运行结果
  • 总结


前言

最近,我们想比较基于DETR的transformer模型与基于CNN的yolo模型效果,而百度RT-DETR模型声称“在实时目标检测领域打败YOLO”。从数据的角度来看,RT-DETR似乎确实在某些方面超越了YOLO。我选择RT-DETR模型与YOLO模型比较。本篇文章将介绍RT-DETR模型原理–>环境安装–>数据准备–>训练实现–>预测实现。


一、摘要

近期,端到端基于transformer检测器DETRs已有显著性能。然而,DETR的计算成本限制其实际应用,也阻止其无后处理的优势(如:NMS)。在这篇论文,我们首次分析NMS对目标检测的速度与精确率影响,并构建了端到端的speed基准。为了解决这些问题,我们提出RT-DETR模型,据我们所知,这是第一个实时端到端检测模型。特别的,我们设计一个高效混合编码器加工多尺度特征与特征交互和融合,并提出IOU感知查询,通过像解码器提供更高初始目标来提示性能。除此之外,我们提出的检测模型,可使用解码层without retraining灵活调整推理速度,这样可适应多样的实时场景。我们模型RT-DETR-L在coco2017实现了53%的AP和114FPS on T4 gpu,而RT-DETR-X实现54.8%AP和74FPS,超过同规模模型的yolo。此外,我们的 RT-DETR-R50 实现了53.1%的AP和108FPS的速度,准确性优于 DINO-Deformable-DETR-R50 约 2.2% AP,帧率约为其21倍。
在这里插入图片描述

二、论文目的

实时目标检测是一个重要的研究领域,而DETR的高计算成本问题尚未得到有效解决,这限制了DETR的实际应用,并导致无法充分利用其优势(后处理)。换句话说,RTDETR解决问题是

为了实现上述目标,我们重新思考了DETR,并对其关键组件进行了详细分析和实验,以减少不必要的计算冗余。具体而言,我们发现虽然引入多尺度特征有助于加快训练收敛和提高性能[43],但它也导致输入编码器的序列长度显著增加。因此,由于高计算成本,Transformer编码器成为模型的计算瓶颈。为了实现实时目标检测,我们设计了一个高效的混合编码器来替代原始的Transformer编码器。通过解耦多尺度特征的内尺度交互和跨尺度融合,编码器能够高效处理具有不同尺度的特征。此外,先前的研究[35, 20]表明,解码器的对象查询初始化方案对于检测性能至关重要。为了进一步提高性能,我们提出了基于IoU的查询选择方法,通过在训练过程中提供IoU约束,为解码器提供更高质量的初始对象查询。此外,我们提出的检测器支持通过使用不同的解码器层对推理速度进行灵活调节,无需重新训练,这得益于DETR架构中解码器的设计,并有助于实时检测器的实际应用。

三、论文贡献

本论文的主要贡献总结如下:

1、我们提出了第一个实时端到端目标检测器,不仅在准确性和速度方面优于当前最先进的实时检测器,而且不需要后处理,因此推理速度不会延迟并保持稳定;

2、我们详细分析了NMS对实时检测器的影响,并从后处理的角度得出了关于基于CNN的实时检测器的结论;

3、我们提出的IoU-aware查询选择在模型中展现出卓越的性能改进,为改进目标查询的初始化方案提供了新的思路;

4、我们的工作为端到端检测器的实时实现提供了可行的解决方案,所提出的检测器可以通过使用不同的解码器层进行灵活调整模型大小和推理速度,无需重新训练。

四、模型结构

1、模型整体结构

RT-DETR模型由主干网络(backbone)、混合编码器(hybrid encoder)和带有辅助预测头的Transformer解码器组成。模型的整体架构概述如下图所示。具体来说,我们利用主干网络最后三个阶段的输出特征{S3,S4,S5}作为编码器的输入。混合编码器通过内部尺度交互和跨尺度融合,将多尺度特征转换为图像特征序列。随后,采用IoU感知的查询选择机制,从编码器的输出序列中选择固定数量的图像特征作为解码器的初始对象查询。最后,带有辅助预测头的解码器迭代优化对象查询,生成边界框和置信度分数。
在这里插入图片描述
RT-DETR模型架构图显示了主干网络的最后三个阶段{S3,S4,S5}作为编码器的输入。高效的混合编码器通过内部尺度特征交互(AIFI)和跨尺度特征融合模块(CCFM)将多尺度特征转化为图像特征序列。采用IoU感知的查询选择方法,选择固定数量的图像特征作为解码器的初始对象查询。最后,解码器通过辅助预测头迭代优化对象查询,生成边界框和置信度分数
本文最重要是设计AIFI与CCFM结构

2、backbone结构

与YOLO相似,RT-DETR最终会输出三种不同尺寸的特征图,它们相对于输入图像的分辨率下采样倍数分别是 8 倍、16 倍和 32 倍。这与主流的YOLO算法相似。除此之外,在主干结构的其他方面,RT-DETR并没有特别的地方。

3、neck结构

对于颈部网络部分,RT-DETR 采用了一层 Transformer 的 Encoder ,文中这个颈部网络叫做 Efficient Hybrid Encoder,其包括两部分:Attention-based Intra-scale Feature Interaction (AIFI) 和 CNN-based Cross-scale Feature-fusion Module (CCFM),这个AIFI模块有一点值得注意,这个模块只对S5特征图进行处理

对于AIFI模块(如下左图),它首先将二维的 S5 特征拉成向量,然后交给AIFI模块处理,其数学过程就是多头自注意力与 FFN,随后,再将输出Reshape回二维,记作 F5,以便去完成后续的所谓的“跨尺度特征融合”。

对于CCFM模块(如下右图),以YOLO的角度看这个结构的话,这个CCFM模块就是一个FPN/PAN结构。关于CCFM模块中的Fusion文中也给了详细的结构图,是由 2 个1×1 卷积和 N 个 RepBlock 构成的,这里之所以写成 N ,我觉得是因为 RT-DETR 可以进行缩放处理,通过调整 CCFM中RepBlock 的数量和 Encoder 的编码维度分别控制 Hybrid Encoder 的深度和宽度,同时对 backbone 进行相应的调整即可实现检测器的缩放。
在这里插入图片描述

4、混合编码器(neck)

在3已经介绍neck最终结构,而设计neck结构时,作者为了实时性与减少冗余,设计了一些列结构,其原因是注意力机制的改进减少了计算开销,却输入序列的大幅增加仍导致编码器成为计算瓶颈,不太好实时场景中使用。作者分析了多尺度变换器编码器中存在的计算冗余,设计了一系列变种来证明同时进行内部尺度和跨尺度特征交互在计算上效率低下。

在这里插入图片描述
A → B:变体B插入了一个单尺度的Transformer编码器,它使用了一个Transformer块的层。每个尺度的特征共享编码器,进行内部尺度的特征交互,然后将输出的多尺度特征进行连接。
B → C:变体C在B的基础上引入了基于尺度的特征融合,将连接的多尺度特征输入编码器进行特征交互。
C → D:变体D将多尺度特征的内部尺度交互和跨尺度融合解耦。首先,使用单尺度的Transformer编码器进行内部尺度交互,然后利用类似于PANet [21]的结构进行跨尺度融合。
D → E:变体E在D的基础上进一步优化多尺度特征的内部尺度交互和跨尺度融合,采用了我们设计的高效混合编码器。

RT-DETR认为S5特征相对于较浅的S3和S4特征来说,具有更深、更高级和更丰富的语义特征。这些语义特征对于Transformer模型更加重要,因为它们对于区分不同物体的特征非常有用,而浅层特征由于缺乏良好的语义特征并不是很丰富。

五、RTDERT模型训练(data–>train)

我将在此部分介绍环境安装、数据准备格式、训练相关配置与代码、预测相关内容与代码,我也将数据、官网提供权重放在这里,有需要自行下载。

1、环境安装

使用命令安装,如下:

conda create -n yolov8 python=3.8
conda activate yolov8
git clone https://github.com/ultralytics/ultralytics.git
cd ultralytics
pip install -r requirement.txt
pip install ultralytics

使用上面命令安装可能会报错Could not load library libcudnn_cnn_train.so.8 ,解决方法点击这里,建议先安装较低点的torch版本。

2、训练

我们使用yolov8集成的RTDETR模型,训练与预测文件大致如下图。
在这里插入图片描述

1、数据准备

实际为yolo数据格式,可按照yolov5或v8格式准备即可。

2、数据yaml文件

其数据yaml文件与yolo差不多,但少了nc且将names变成字典的映射,coco8.yaml内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
# Example usage: yolo train data=coco8.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── coco8  ← downloads here (1 MB)# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: C:/Users/Administrator/Desktop/rtdetr/coco128  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/train  # val images (relative to 'path') 4 images
test:  # test images (optional)# Classes
names:0: person1: bicycle2: car3: motorcycle4: airplane5: bus6: train7: truck8: boat9: traffic light10: fire hydrant11: stop sign12: parking meter13: bench14: bird15: cat16: dog17: horse18: sheep19: cow20: elephant21: bear22: zebra23: giraffe24: backpack25: umbrella26: handbag27: tie28: suitcase29: frisbee30: skis31: snowboard32: sports ball33: kite34: baseball bat35: baseball glove36: skateboard37: surfboard38: tennis racket39: bottle40: wine glass41: cup42: fork43: knife44: spoon45: bowl46: banana47: apple48: sandwich49: orange50: broccoli51: carrot52: hot dog53: pizza54: donut55: cake56: chair57: couch58: potted plant59: bed60: dining table61: toilet62: tv63: laptop64: mouse65: remote66: keyboard67: cell phone68: microwave69: oven70: toaster71: sink72: refrigerator73: book74: clock75: vase76: scissors77: teddy bear78: hair drier79: toothbrush# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8.zip

3、训练代码

我们使用命令训练,如下代码:

yolo train model=rtdetr-l.pt data=coco8.yaml epochs=100 imgsz=640 batch=2 amp=False

4、训练运行结果

配置好以上内容即可训练,执行过程如下显示
在这里插入图片描述

3、推理

1、推理代码

这里不在过多介绍推理代码,朋友们可自行查阅。

import cv2
import torch
import numpy as np
from ultralytics.nn.autobackend import AutoBackenddef preprocess(image):image = cv2.resize(image, (640, 640))image = (image[..., ::-1] / 255.0).astype(np.float32) # BGR to RGB, 0 - 255 to 0.0 - 1.0image = image.transpose(2, 0, 1)[None]  # BHWC to BCHW (n, 3, h, w)image = torch.from_numpy(image)return imagedef postprocess(pred, oh, ow, conf_thres=0.25):# 输入是模型推理的结果,即300个预测框# 1,300,84 [cx,cy,w,h,class*80]boxes = []for item in pred[0]:cx, cy, w, h = item[:4]label = item[4:].argmax()confidence = item[4 + label]if confidence < conf_thres:continueleft    = cx - w * 0.5top     = cy - h * 0.5right   = cx + w * 0.5bottom  = cy + h * 0.5boxes.append([left, top, right, bottom, confidence, label])boxes = np.array(boxes)lr = boxes[:,[0, 2]]tb = boxes[:,[1, 3]]boxes[:,[0,2]] = ow * lrboxes[:,[1,3]] = oh * tbreturn boxesdef hsv2bgr(h, s, v):h_i = int(h * 6)f = h * 6 - h_ip = v * (1 - s)q = v * (1 - f * s)t = v * (1 - (1 - f) * s)r, g, b = 0, 0, 0if h_i == 0:r, g, b = v, t, pelif h_i == 1:r, g, b = q, v, pelif h_i == 2:r, g, b = p, v, telif h_i == 3:r, g, b = p, q, velif h_i == 4:r, g, b = t, p, velif h_i == 5:r, g, b = v, p, qreturn int(b * 255), int(g * 255), int(r * 255)def random_color(id):h_plane = (((id << 2) ^ 0x937151) % 100) / 100.0s_plane = (((id << 3) ^ 0x315793) % 100) / 100.0return hsv2bgr(h_plane, s_plane, 1)if __name__ == "__main__":img = cv2.imread("1.jpg")oh, ow = img.shape[:2]img_pre = preprocess(img)# postprocess# ultralytics/models/rtdetr/predict.pymodel  = AutoBackend(weights="rtdetr-l.pt")names  = model.namesresult = model(img_pre)[0]  # 1,300,84boxes  = postprocess(result, oh, ow)for obj in boxes:left, top, right, bottom = int(obj[0]), int(obj[1]), int(obj[2]), int(obj[3])confidence = obj[4]label = int(obj[5])color = random_color(label)cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)caption = f"{names[label]} {confidence:.2f}"w, h = cv2.getTextSize(caption, 0, 1, 2)[0]cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)cv2.imwrite("infer.jpg", img)print("save done")  

注:若下载了文件可直接 python detect.py执行,可得结果

2、推理运行结果

在这里插入图片描述


总结

文章主要是更换backbone(个人觉得不是文章重点),而使用S5在结合作者多个neck模块实验,该neck结构主打消除计算实现实时。
代码可使用百度官网代码,也可使用yolov8自带代码(高效实现)。
后期,我将仿yolov5一键训练与预测,直接使用xml文件格式训练有预测RTDETR文章。

参考博客链接:
https://blog.csdn.net/qq_40672115/article/details/134356250
https://blog.csdn.net/weixin_43694096/article/details/131353118

这篇关于RTDETR论文快速理解和代码快速实现(训练与预测)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/518496

相关文章

Java实现检查多个时间段是否有重合

《Java实现检查多个时间段是否有重合》这篇文章主要为大家详细介绍了如何使用Java实现检查多个时间段是否有重合,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录流程概述步骤详解China编程步骤1:定义时间段类步骤2:添加时间段步骤3:检查时间段是否有重合步骤4:输出结果示例代码结语作

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import