从DETR到Mask2Former(1):DETR-segmentation结构全解析

2024-01-12 22:44

本文主要是介绍从DETR到Mask2Former(1):DETR-segmentation结构全解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

网上关于DETR做的detection的解析很多,但是DETR做Segmentation的几乎没有,本文结合DETR的论文与代码,对DETR做一个详细的拆解。理解DETR是理解Mask2Former的基础。

首先得把DETR-segmentation给run起来。Github上DETR的repository,下载了也只能run起来detection,run不起来segmentation功能,但还是下载下来,后面留着有用。我们用torch的hub里集成的DETR segmentation模型,运行下面的代码

import torch
models_list = torch.hub.list('facebookresearch/detr', force_reload=True)
print(models_list)

你可以看到torch.hub中所有关于detr的模型。我们选择  detr_resnet50_panoptic

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)

再新建一个py文件,把以下代码放进去:

import math
from PIL import Image
import requests
import matplotlib.pyplot as pltimport ipywidgets as widgets
from IPython.display import display, clear_outputimport torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)import matplotlib.pyplot as plt# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return bdef plot_results(pil_img, prob, boxes):plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()colors = COLORS * 100for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))# cl = p.argmax()# text = f'{CLASSES[cl]}: {p[cl]:0.2f}'# ax.text(xmin, ymin, text, fontsize=15,#         bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.show()model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)
model.eval()url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)# propagate through the model
outputs = model(img)# show result of detection
# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
plot_results(im, probas[keep], bboxes_scaled)# show segmentation
# compute the scores, excluding the "no-object" class (the last one)
scores = outputs['pred_logits']
scores = scores.softmax(dim=-1)
scores = scores[..., :-1]
scores_onehot = scores.max(-1)
scores_onehot = scores_onehot[0]
# threshold the confidence
keep = scores_onehot > 0.85scores_selected = scores[keep]
labels = torch.argmax(scores_selected, dim=-1)masks = outputs['pred_masks'][keep].detach().cpu().numpy()
mask_i = masks[0, :, :]# plt.imshow(mask_i, cmap="viridis")
# plt.show()# ## Plot all the remaining masks
ncols = 5
fig, axs = plt.subplots(ncols=ncols, nrows=math.ceil(keep.sum().item() / ncols), figsize=(18, 10))
for line in axs:for a in line:a.axis('off')
for i, mask in enumerate(outputs['pred_masks'][keep].detach().cpu().numpy()):ax = axs[i // ncols, i % ncols]ax.imshow(mask, cmap="cividis")ax.text(0, 0, labels[i].cpu().numpy().item(0))ax.axis('off')
fig.tight_layout()plt.show()mask_pred = outputs['pred_masks'].sigmoid()
mask_pred = F.interpolate(mask_pred, size=(480, 640), mode='bilinear', align_corners=False)semseg = torch.einsum("bqc,bqhw->bchw", scores, mask_pred)
result = torch.argmax(semseg, dim=1)plt.figure(figsize=(12, 8))
# 第一个子图
plt.subplot(1, 1, 1)
# image1_rgb = cv2.cvtColor(img_color, cv2.COLOR_BGR2RGB)
plt.imshow(result[0, :, :].to("cpu").numpy())
plt.title('Image 1')
plt.axis('off')plt.show()

解释一下上面的代码,

plot_results(im, probas[keep], bboxes_scaled)

是把detection的结果可视化出来,本文略过的detection部分的内容。  plot_results()  下面的代码就是segmentation内容结果的可视化。模型的输出包括三个内容

pred_logits(1, 100, 251)中三个数字的含义是:1是batchsize,100是query的个数,251是分类数,去掉最后一个no-object不要,实际是250个类。可以这样理解,每一个query都会进入网络,但不是每一个query都能从图片找到东西,而找到东西的query,找到的instance所属于的类就是这个query所在的那一行中,最大的数所在列的index。如果一个图片里有两只猫,那结果就是有两个query,分别各自找到一只猫(也就是一个instance),这样也就实现了全景分割的功能。所以说DETR的结构是Mask2Former实现全景分割功能的基础。

pred_boxes是query所找到的instance对应的box。

pred_masks是query所找到的instance所在的像素。

这段代码:

# ## Plot all the remaining masks
ncols = 5
fig, axs = plt.subplots(ncols=ncols, nrows=math.ceil(keep.sum().item() / ncols), figsize=(18, 10))
for line in axs:for a in line:a.axis('off')
for i, mask in enumerate(outputs['pred_masks'][keep].detach().cpu().numpy()):ax = axs[i // ncols, i % ncols]ax.imshow(mask, cmap="cividis")ax.text(0, 0, labels[i].cpu().numpy().item(0))ax.axis('off')
fig.tight_layout()plt.show()

会画出这样的图像:

有94个query没有找到instance,有6个query找到了instance,2个query找到了猫,也就是250列的17列,两个query找到了遥控器(250列的74列)。

好了,现在DETR-seg的输出我们弄清楚了,接下来进到DETR内部去看看,这个模型封装在了torch.hub中,进入内部的正确方法,说实话笔者也不知道,这里笔者耍了一个不正经的小trick:

import torchmodel = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)from models.segmentation import DETRsegmmodel = DETRsegm(model)

上面的代码中,models.segmentation文件是DETR的github repository下载下来,里面有的内容,运行上面的代码,会报错

但是通过这个报错,我们可以找到torch.hub中DETR的源代码。这样就可以在源代码里打断点,看DETR 内部了。进入源代码,我们能看到这样的内容,注释的代码是笔者加进去的,不是torch自带的。

self.detr.backbone是resnet18。 这里,我们打开DETR的github代码——detr/models/backbone.py 找到  class BackboneBase 这个类。

这里的 def forward 中的out,就是 DETRsegm 中的 features

features, pos = self.detr.backbone(samples)

pos是正弦函数的 position embeding,就是输入transformer的encoder的位置编码。

在获取features时,DETR detection和Segmentation的区别在与,detection只拿resnet第四层的输出,而segmentation将每层的输出都拿出来。通过这段代码实现(缩进乱了,忽略)。

if return_interm_layers:return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}else:return_layers = {'layer4': "0"}self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

这里扯远一点,复习一下resnet18  resnet18在torch中也有集成,通过下面的代码得到,

import torch
import torchvision.models as models
from torchvision.models._utils import IntermediateLayerGetter
import cv2
import numpy as np# 加载一个预训练的 ResNet 模型
model = models.resnet18(pretrained=True)
model.to("cuda")# 定义要获取的中间层
layers = {'layer1': 'layer1', 'layer2': 'layer2', 'layer3': 'layer3', 'layer4': 'layer4'}# 创建 IntermediateLayerGetter
intermediate_layers = IntermediateLayerGetter(model, layers)image0 = cv2.imread("/home/robotics/dino/img/kitaku/002.jpg")
height0, width0, channels = image0.shape
image = cv2.resize(image0, (640, 480), interpolation=cv2.INTER_AREA)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)mean = np.array([0.485, 0.456, 0.406]) * 255
std = np.array([0.229, 0.224, 0.225]) * 255
image = image.astype(float)
for i in range(3):image[:, :, i] = (image[:, :, i] - mean[i]) / std[i]input_data = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(dim=0).to(torch.float32).to("cuda")# 获取中间层的输出
outputs = intermediate_layers(input_data)# 打印输出
for key, value in outputs.items():print(f"{key}: {value.shape}")from torchview import draw_graph
model_graph = draw_graph(model, input_size=(1, 3, 480, 640))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

通过draw_graph得到一个网络结构可视化图,格式是svg,如下图所示

也可以自己写代码,构建一个resnet18

import torch
import torch.nn as nn
from torch.nn import functional as Fclass RestNetBasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride):super(RestNetBasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):output = self.conv1(x)output = F.relu(self.bn1(output))output = self.conv2(output)output = self.bn2(output)return F.relu(x + output)class RestNetDownBlock(nn.Module):def __init__(self, in_channels, out_channels, stride):super(RestNetDownBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):output = self.conv1(x)out = F.relu(self.bn1(output))out = self.conv2(out)out = self.bn2(out)return F.relu(x + out)class RestNet18(nn.Module):def __init__(self):super(RestNet18, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),RestNetBasicBlock(64, 64, 1))self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),RestNetBasicBlock(128, 128, 1))self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),RestNetBasicBlock(256, 256, 1))self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),RestNetBasicBlock(512, 512, 1))self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))self.fc = nn.Linear(512, 10)def forward(self, x):out = self.conv1(x)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.reshape(x.shape[0], -1)out = self.fc(out)return outif __name__ == "__main__":resnet18 = RestNet18()from torchview import draw_graphmodel_graph = draw_graph(resnet18, input_size=(1, 3, 480, 640))model_graph.resize_graph(scale=5.0)model_graph.visual_graph.render(format='svg')

好,回到DETR,还是看torch.hub中的这段detr代码:

 上面的代码中:

hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

self.detr.transformer就是下图经典的transformer结构,memory是encoder的输出  hs是decoder的输出。mask对于理解detr不重要,是masked attention中用到了。

pos和query_embed.weight分别对应上图的Positional Encoding和decoder中,output上面的OutputEmbeding。

DETR中transformer的理解,可以参看下面的代码,这段代码是笔者自己写的,对于DETR本身是没有用处的,只是为了方便理解,有错误的地方(feature那里只拿出resnet18最后一层的输出,是错的,不影响大局,懒得改了),自己理解时候做一个参考就好。

import torch
from torch import nn
from torchvision.models import resnet50
from models.transformer import TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder
from models.segmentation import MHAttentionMap, MaskHeadSmallConv
import timeclass DETR(nn.Module):def __init__(self, num_classes, d_model, nheads, dim_feedforward=2048,num_encoder_layers=6, num_decoder_layers=6, dropout=0.9, activation='relu', normalize_before=True):super().__init__()# We take only convolutional layers from ResNet-50 modelself.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])self.conv = nn.Conv2d(dim_feedforward, d_model, 1)encoder_layer = TransformerEncoderLayer(d_model, nheads, dim_feedforward,dropout, activation, normalize_before)encoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer = TransformerDecoderLayer(d_model, nheads, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=True)self.linear_class = nn.Linear(d_model, num_classes + 1)self.linear_bbox = nn.Linear(d_model, 4)num_queries = 100self.query_embed = nn.Embedding(num_queries, d_model)#position embeddingself.row_embed = nn.Parameter(torch.rand(50, d_model // 2))self.col_embed = nn.Parameter(torch.rand(50, d_model // 2))self.bbox_attention = MHAttentionMap(d_model, d_model, nheads, dropout=0.0)self.mask_head = MaskHeadSmallConv(d_model + nheads, [1024, 512, 256], d_model)def forward(self, inputs):#inputs是[1,3,800,1200]features = self.backbone(inputs)#x是[1,2048,25,38]hh = self.conv(features)#hh是[1,256,25,38]H, W = hh.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)bs, c, h, w = hh.shapequery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)tgt = torch.zeros_like(query_embed)src = hh.flatten(2).permute(2, 0, 1)memory = self.encoder(src=src, pos=pos)hs = self.decoder(tgt, memory, pos=pos, query_pos=query_embed)hs = hs.transpose(1, 2)memory = memory.permute(1, 2, 0).view(bs, c, h, w)bbox_mask = self.bbox_attention(hs[-1], memory)seg_masks = self.mask_head(hh, bbox_mask, [features[0], features[0], features[0]])outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])#self.query_pos是[100,256]#src是encoder输入,tgt是decoder输入#h是[100,1,256]return outputs_seg_masks#coco是91个类, hidden dimension是256, 多头注意力是8, encoder,decoder layer都是6
device = torch.device("cuda")
detr = DETR(num_classes=91, d_model=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval().cuda()inputs = torch.randn(1, 3, 800, 1200).cuda()outputs_seg_masks = detr(inputs)
# print(logits, bboxes)
#logits是[100,1,92]
#bboxes是[100,1,4]

还是回到torch.hub那个detr代码,下面结合DETR论文中的那张图来理解一下这段代码

图中multi head attention部分,对应代码中

bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

hs[-1]是左侧四个不同颜色的小框,memory是Encoded image,mask可以忽略,bbox_mask的尺寸是:(batch_size,query的个数,注意力头的个数,attention map的高和宽)

        temp = bbox_mask[0][20].squeeze().cpu().numpy()import matplotlib.pyplot as pltfor i in range(8):plt.imshow(temp[i, :, :], cmap="viridis")plt.show()

上面那段代码就是拿出第1个batchsize的第21个query的所有注意力头的attention map,也就是可以画出这几张图:

 

seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])

mask_head对应图中这个结构:

输入是多头注意力map,以及resnet18这个backbone的4层的特征。

本文先到这,下一篇写DETR的损失函数。

未完待续

这篇关于从DETR到Mask2Former(1):DETR-segmentation结构全解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/599398

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

usaco 1.3 Mixing Milk (结构体排序 qsort) and hdu 2020(sort)

到了这题学会了结构体排序 于是回去修改了 1.2 milking cows 的算法~ 结构体排序核心: 1.结构体定义 struct Milk{int price;int milks;}milk[5000]; 2.自定义的比较函数,若返回值为正,qsort 函数判定a>b ;为负,a<b;为0,a==b; int milkcmp(const void *va,c

自定义类型:结构体(续)

目录 一. 结构体的内存对齐 1.1 为什么存在内存对齐? 1.2 修改默认对齐数 二. 结构体传参 三. 结构体实现位段 一. 结构体的内存对齐 在前面的文章里我们已经讲过一部分的内存对齐的知识,并举出了两个例子,我们再举出两个例子继续说明: struct S3{double a;int b;char c;};int mian(){printf("%zd\n",s

OWASP十大安全漏洞解析

OWASP(开放式Web应用程序安全项目)发布的“十大安全漏洞”列表是Web应用程序安全领域的权威指南,它总结了Web应用程序中最常见、最危险的安全隐患。以下是对OWASP十大安全漏洞的详细解析: 1. 注入漏洞(Injection) 描述:攻击者通过在应用程序的输入数据中插入恶意代码,从而控制应用程序的行为。常见的注入类型包括SQL注入、OS命令注入、LDAP注入等。 影响:可能导致数据泄

从状态管理到性能优化:全面解析 Android Compose

文章目录 引言一、Android Compose基本概念1.1 什么是Android Compose?1.2 Compose的优势1.3 如何在项目中使用Compose 二、Compose中的状态管理2.1 状态管理的重要性2.2 Compose中的状态和数据流2.3 使用State和MutableState处理状态2.4 通过ViewModel进行状态管理 三、Compose中的列表和滚动

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

CSP 2023 提高级第一轮 CSP-S 2023初试题 完善程序第二题解析 未完

一、题目阅读 (最大值之和)给定整数序列 a0,⋯,an−1,求该序列所有非空连续子序列的最大值之和。上述参数满足 1≤n≤105 和 1≤ai≤108。 一个序列的非空连续子序列可以用两个下标 ll 和 rr(其中0≤l≤r<n0≤l≤r<n)表示,对应的序列为 al,al+1,⋯,ar​。两个非空连续子序列不同,当且仅当下标不同。 例如,当原序列为 [1,2,1,2] 时,要计算子序列 [

多线程解析报表

假如有这样一个需求,当我们需要解析一个Excel里多个sheet的数据时,可以考虑使用多线程,每个线程解析一个sheet里的数据,等到所有的sheet都解析完之后,程序需要提示解析完成。 Way1 join import java.time.LocalTime;public class Main {public static void main(String[] args) thro

ZooKeeper 中的 Curator 框架解析

Apache ZooKeeper 是一个为分布式应用提供一致性服务的软件。它提供了诸如配置管理、分布式同步、组服务等功能。在使用 ZooKeeper 时,Curator 是一个非常流行的客户端库,它简化了 ZooKeeper 的使用,提供了高级的抽象和丰富的工具。本文将详细介绍 Curator 框架,包括它的设计哲学、核心组件以及如何使用 Curator 来简化 ZooKeeper 的操作。 1