【目标检测】DEtection TRansformer (DETR)

2024-05-03 15:12

本文主要是介绍【目标检测】DEtection TRansformer (DETR),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、前言

论文: End-to-End Object Detection with Transformers
作者: Facebook AI
代码: DEtection TRansformer (DETR)
特点: 无proposal(R-CNN系列)、无anchor(YOLO系列)、无NMS的、端到端的目标检测方法。

二、框架

DETR总体框架图如下:

可见,其主要结构包括四个部分:backbone、encoder、decoder、prediction heads。

2.1 Backbone

输入图像先经过backbone进行特征提取,原文使用ResNet-50。此时,通道数变为2048,图像高宽变为原来的 1 32 \frac{1}{32} 321。再经过一个卷积核大小为1*1的卷积层,将通道数降低至256。

尺寸变换情况如下(同一批次的图像会经过padding统一大小):
[ b a t c h _ s i z e , 3 , h e i g h t , w i d t h ] → [ b a t c h _ s i z e , 2048 , h e i g h t / 32 , w i d t h / 32 ] → [ b a t c h _ s i z e , 256 , h e i g h t / 32 , w i d t h / 32 ] [batch \_size,3,height,width]\rightarrow[batch\_size,2048,height/32,width/32]\rightarrow[batch\_size,256,height/32,width/32] [batch_size,3,height,width][batch_size,2048,height/32,width/32][batch_size,256,height/32,width/32]

2.2 Encoder

Encoder结构如下图左侧所示:

可见,Encoder包括 N N N个这样的组件(原文中有6个),每个组件包括Spatial positional encoding、残差结构(当前输出+之前的输入)、Multi-Head Self-Attention、LayerNorm、FNN(全连接+激活+Dropout+全连接+Dropout)。

值得注意的是:
(1) DETR的位置编码采用了正余弦交替表达各像素点横纵坐标的方式,详情见我关于位置编码的博客(Spatial positional encoding)。
(2) DETR的位置编码仅加在了注意力模块中的Q、K上,这表明计算权重时会使用位置信息,但被传递至下一层的数据中不包含位置信息。原注意力模块的位置编码在Q、K、V上均有体现,详情见我关于注意力的博客(Multi-Head Self-Attention)。

2.3 Decoder

Decoder结构如下图右侧所示:

Decoder重复次数 M M M也是6,其包含的组件主要有位置编码、残差结构、Multi-Head Self-Attention、LayerNorm、Multi-Head Attention、FNN(全连接+激活+Dropout+全连接+Dropout)。

需要注意的有以下几点:
(1) Decoder的输入变为Object queries。Object queries是一个大小为100*256、初始全为0的可学习参数。100表示模型最多预测出100个目标框,256与图像特征通道数一致可保证注意力机制的正常运算。Decoder对应Object queries的输出经Prediction heads后将用于计算损失、预测框坐标、预测类别。
(2) Decoder中的位置编码与Object queries尺寸是一致的。没有使用Spatial positional encoding,而是由nn.Embedding随机初始化,全程保持不变。Decoder中的位置编码作用于Multi-Head Self-Attention前的Q、K和Multi-Head Attention前的Q。
(3) Decoder中有两个注意力模块,先通过Multi-Head Self-Attention,再通过Multi-Head Attention。Multi-Head Self-Attention是对Object queries执行自注意力,Decoder中的位置编码仅作用于Q、K。Multi-Head Attention以Object queries为Q,以Encoder的输出为K和V。Decoder中的位置编码作用于Q,Encoder中的位置编码作用于K。

2.4 Prediction heads

Prediction heads结构如下图右上部分所示:

可见,与其他目标检测方法一样,DETR也是对类别和边界框进行预测。

类别预测头的FFN是一层简单的全连接,例如在COCO数据集中为 256 → 92 256\rightarrow 92 25692(92=类别总数91+背景类1,实际COCO为80个有效类但给了91个类)。

边界框预测头的FFN是一个MLP,包括三层全连接和一层激活: 256 → 256 → 256 → 4 → s i g m o i d 256\rightarrow256\rightarrow256\rightarrow4\rightarrow sigmoid 2562562564sigmoid(4=左上角坐标2+右下角坐标2)。预测的坐标是归一化的,实际计算时需要映射至原图。

三、训练

如上所述,预测结果共有100个,每个都有对类别和边界框的预测。但是实际一张图像中目标数量通常不足100,DETR通过二分匹配为每个真实目标寻找一个最匹配的预测用于框相关损失的计算。

3.1 二分匹配

DETR使用匈牙利匹配算法,为每个真实目标寻找一个最匹配的预测。想要进行匹配首先要有对每个预测结果的衡量指标,DETR使用了三种指标:
(1) 在真实目标类上的预测概率(分类头输出经SoftMax后获得)。
(2) 所有预测框坐标与所有真实目标框坐标间的曼哈顿距离(残差的绝对值之和,即L1损失)。
(3) 所有预测框与所有真实目标框间的GIOU(一种改进的IoU指标)。
其中,(1)和(3)越大越好,(2)越小越好,所以衡量指标被定义为 − ( 1 ) + ( 2 ) − ( 3 ) -(1)+(2)-(3) (1)+(2)(3)

匈牙利算法能够根据衡量指标为每个真实目标都找的一个最匹配的预测。所以DETR通过二分匹配而非NMS确定用于计算框相关损失的预测。

3.2 损失

损失包括如下三项:
(1) 交叉熵损失。每张图片有100个预测,未被匹配的预测所对应的真实标签被置为背景类91。
(2) L1损失。残差的绝对值之和。
(3) GIOU损失。 1 − G I O U 1-GIOU 1GIOU
交叉熵损失针对所有预测,L1损失和GIOU损失仅针对与真实目标匹配的预测。交叉熵损失仅针对类别预测,L1损失和GIOU损失仅针对框预测。

四、测试

测试时无需计算损失,也不需要NMS,直接保留类别预测概率大于阈值的预测即可。

这篇关于【目标检测】DEtection TRansformer (DETR)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/957033

相关文章

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

3月份目标——刷完乙级真题

https://www.patest.cn/contests/pat-b-practisePAT (Basic Level) Practice (中文) 标号标题通过提交通过率1001害死人不偿命的(3n+1)猜想 (15)31858792260.41002写出这个数 (20)21702664840.331003我要通过!(20)11071447060.251004成绩排名 (20)159644

基于深度学习的轮廓检测

基于深度学习的轮廓检测 轮廓检测是计算机视觉中的一项关键任务,旨在识别图像中物体的边界或轮廓。传统的轮廓检测方法如Canny边缘检测和Sobel算子依赖于梯度计算和阈值分割。而基于深度学习的方法通过训练神经网络来自动学习图像中的轮廓特征,能够在复杂背景和噪声条件下实现更精确和鲁棒的检测效果。 深度学习在轮廓检测中的优势 自动特征提取:深度学习模型能够自动从数据中学习多层次的特征表示,而不需要

自动驾驶---Perception之Lidar点云3D检测

1 背景         Lidar点云技术的出现是基于摄影测量技术的发展、计算机及高新技术的推动以及全球定位系统和惯性导航系统的发展,使得通过激光束获取高精度的三维数据成为可能。随着技术的不断进步和应用领域的拓展,Lidar点云技术将在测绘、遥感、环境监测、机器人等领域发挥越来越重要的作用。         目前全球范围内纯视觉方案的车企主要包括特斯拉和集越,在达到同等性能的前提下,纯视觉方

YOLOv9摄像头或视频实时检测

1、下载yolov9的项目 地址:YOLOv9 2、使用下面代码进行检测 import torchimport cv2from models.experimental import attempt_loadfrom utils.general import non_max_suppression, scale_boxesfrom utils.plots import plot_o

2025秋招NLP算法面试真题(二)-史上最全Transformer面试题:灵魂20问帮你彻底搞定Transformer

简单介绍 之前的20个问题的文章在这里: https://zhuanlan.zhihu.com/p/148656446 其实这20个问题不是让大家背答案,而是为了帮助大家梳理 transformer的相关知识点,所以你注意看会发现我的问题也是有某种顺序的。 本文涉及到的代码可以在这里找到: https://github.com/DA-southampton/NLP_ability 问题

Java内存泄漏检测和分析介绍

在Java中,内存泄漏检测和分析是一个重要的任务,可以通过以下几种方式进行:   1. 使用VisualVM VisualVM是一个可视化工具,可以监控、分析Java应用程序的内存消耗。它可以显示堆内存、垃圾收集、线程等信息,并且可以对内存泄漏进行分析。 2. 使用Eclipse Memory Analyzer Eclipse Memory Analyzer(MAT)是一个强大的工具,可

Failed to establish a new connection: [WinError 10061] 由于目标计算机积极拒绝,无法连接

在进行参数化读取时发现一个问题: 发现问题: requests.exceptions.ConnectionError: HTTPConnectionPool(host='localhost', port=8081): Max retries exceeded with url: /jwshoplogin/user/update_information.do (Caused by NewConn

基于CDMA的多用户水下无线光通信(3)——解相关多用户检测

继续上一篇博文,本文将介绍基于解相关的多用户检测算法。解相关检测器的优点是因不需要估计各个用户的接收信号幅值而具有抗远近效应的能力。常规的解相关检测器有运算量大和实时性差的缺点,本文针对异步CDMA的MAI主要来自干扰用户的相邻三个比特周期的特点,给出了基于相邻三个匹配滤波器输出数据的截断解相关检测算法。(我不知道怎么改公式里的字体,有的字母在公式中重复使用了,请根据上下文判断字母含义) 1

算是一些Transformer学习当中的重点内容

一、基础概念         Transformer是一种神经网络结构,由Vaswani等人在2017年的论文Attentions All YouNeed”中提出,用于处理机器翻译、语言建模和文本生成等自然语言处理任务。Transformer同样是encoder-decoder的结构,只不过这里的“encoder”和“decoder”是由无数个同样结构的encoder层和decoder层堆叠组成