RegionCLIP网络结构解析 Region-based Language-Image Pretraining

本文主要是介绍RegionCLIP网络结构解析 Region-based Language-Image Pretraining,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、简单介绍

主要是关注目标检测方面的工作,现在纯CV已经前景黯淡,即使前段时间的YOLOv9发布也是关注一般。
现在大模型已成热点,而大模型要求的数据量和算力和算法复杂度,显然让很多人却步。但是具有大模型特点的多模态算法也算是研究的趋势,所以目前主要是关注多模态方面的目标检测工作。

其中目标检测领域,目前和多模态相关的主要是 开集、开放词汇、描述性目标检测以及情景理解等。相关的研究工作已经越来越多,这里权当学习记录。

RegionCLIP作为OVD检测算法,也是具有一定的代表性。

RegionCLIP的官方网址:https://github.com/microsoft/RegionCLIP
RegionCLIP的论文网址:https://arxiv.org/pdf/2112.09106.pdf

在这里插入图片描述

文章概述(摘自GitHub):

我们提出了 RegionCLIP,它显着扩展了 CLIP 以学习区域级的视觉表示。RegionCLIP支持图像区域和文本概念之间的细粒度对齐,从而支持基于区域的推理任务,包括零样本对象检测和开放词汇对象检测。

①预训练:我们利用 CLIP 模型将图像区域与模板标题进行匹配,然后预训练模型以对齐这些区域-文本对。

②零样本推理:预训练后,学习区域表示支持用于对象检测的零样本推理。

③迁移学习:学习的 RegionCLIP 模型可以通过额外的对象检测注释进一步微调,从而允许我们的模型用于完全监督或开放词汇的对象检测。

④结果:我们的方法展示了零样本目标检测和开放词汇目标检测的最新结果。

在这里插入图片描述

概括一下:核心思想就是把之前 图像特征和文本特征匹配的方式 聚焦到了 图像的局部区域特征 和文本特征的匹配

2、网络结构

大致看了代码,RegionCLIP是基于detectron2写的,包括预训练模型的训练和Fast RCNN结构的网络

2.1 预训练配置:

在这里插入图片描述
可以看到,这个预训练模型的结构是 PretrainFastRCNN
代码在
在这里插入图片描述
可以看到他的forward函数:

	def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):if not self.training:return self.inference(batched_inputs)gt_instances = Nonelosses = {}# localization branch: offline modules to get the region proposalsproposals = self.get_region_proposals(batched_inputs)global_proposals = self.create_global_proposals(batched_inputs)# recognition branch: get 2D feature maps using the backbone of recognition branch and extract region featuresimages = self.preprocess_image(batched_inputs)features = self.backbone(images.tensor)region_feats = self.get_region_features(images, features, proposals, gt_instances)global_feats = self.get_region_features(images, features, global_proposals, gt_instances)# image-text level matchingif self.img_txt_level:self.image_text_matching(batched_inputs, proposals, region_feats, losses, global_feats=global_feats)# region-concept level matchingif self.concept_emb is not None:self.region_concept_matching(images, proposals, gt_instances, region_feats, losses)return losses

从上可以看到区域选取是通过 self.get_region_proposals(batched_inputs) 实现的 ,

self.get_region_features(images, features, proposals, gt_instances) 这个是获取区域图像的特征

self.region_concept_matching(images, proposals, gt_instances, region_feats, losses) 是 区域图像特征和文本特征匹配的

2.2 CLIPFastRCNN 结构

如下配置文件,可以看到整体的网络配置
在这里插入图片描述

class CLIPFastRCNN(nn.Module):"""Fast R-CNN style where the cropping is conducted on feature maps instead of raw images.It contains the following two components: 1. Localization branch: pretrained backbone+RPN or equivalent modules, and is able to output object proposals2. Recognition branch: is able to recognize zero-shot regions"""@configurabledef __init__(self,*,offline_backbone: Backbone,backbone: Backbone,offline_proposal_generator: nn.Module,language_encoder: nn.Module, roi_heads: nn.Module,pixel_mean: Tuple[float],pixel_std: Tuple[float],input_format: Optional[str] = None,vis_period: int = 0,clip_crop_region_type: str = 'GT',use_clip_c4: False,use_clip_attpool: False,offline_input_format: Optional[str] = None,offline_pixel_mean: Tuple[float],offline_pixel_std: Tuple[float],):

这是定义的 CLIPFastRCNN 的初始内容,包含要传递的参数模块
其中backbone 是 build_clip_resnet_backbone,这个可以在如下找到

def build_backbone(cfg, input_shape=None):"""Build a backbone from `cfg.MODEL.BACKBONE.NAME`.Returns:an instance of :class:`Backbone`"""if input_shape is None:input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))backbone_name = cfg.MODEL.BACKBONE.NAMEbackbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)assert isinstance(backbone, Backbone)return backbone

也就是通过 cfg.MODEL.BACKBONE.NAME 来定位到定义的backbone,如下:
在这里插入图片描述

可以看到,最终返回一个 ModifiedResNet

其中用了配置文件中的 MODEL.BACKBONE.FREEZE_ATMODEL.RESNETS.OUT_FEATURESMODEL.RESNETS.DEPTH具体如下:

def build_clip_resnet_backbone(cfg, input_shape):"""Create a CLIP-version ResNet instance from config.Returns:ModifiedResNet: a :class:`ModifiedResNet` instance."""# port standard ResNet config to CLIP ModifiedResNetfreeze_at           = cfg.MODEL.BACKBONE.FREEZE_ATout_features        = cfg.MODEL.RESNETS.OUT_FEATURESdepth               = cfg.MODEL.RESNETS.DEPTH# num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS# width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP# bottleneck_channels = num_groups * width_per_group# in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS# out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS# stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1# res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION# deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE# deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED# deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPSnum_blocks_per_stage = {18: [2, 2, 2, 2],34: [3, 4, 6, 3],50: [3, 4, 6, 3],101: [3, 4, 23, 3],152: [3, 8, 36, 3],200: [4, 6, 10, 6], # flag for ResNet50x4}[depth]vision_layers = num_blocks_per_stagevision_width = {50: 64,101: 64,200: 80, # flag for ResNet50x4}[depth]  # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS# default configs of CLIP ModifiedResNet, but not used if only building ModifiedResNet as backboneembed_dim = {50: 1024,101: 512,200: 640, # flag for ResNet50x4}[depth] vision_heads = vision_width * 32 // 64image_resolution = {50: 224,101: 224,200: 288, # flag for ResNet50x4}[depth] # if combine {ModifiedResNet of CLIP, C4, text emb as classifier}, then has to use att_pool to match dimensioncreate_att_pool = True if (cfg.MODEL.ROI_HEADS.NAME in ['CLIPRes5ROIHeads', 'CLIPStandardROIHeads'] and cfg.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER)\or cfg.MODEL.ROI_HEADS.NAME == 'PretrainRes5ROIHeads' else Falsereturn ModifiedResNet(layers=vision_layers, output_dim=embed_dim,heads=vision_heads,input_resolution=image_resolution,width=vision_width,out_features=out_features, freeze_at=freeze_at,depth=depth,pool_vec=False,create_att_pool=create_att_pool,)

继续看 CLIPFastRCNN ,其中 @classmethod ----- 类方法让类模板具有记忆力,用@classmethod描述类方法,然后用"cls"代表本类。

    @classmethoddef from_config(cls, cfg):# create independent backbone & RPNif cfg.MODEL.CLIP.CROP_REGION_TYPE == "RPN": # create offline cfg for the pretrained backbone & RPNfrom detectron2.config import get_cfgoffline_cfg = get_cfg()offline_cfg.merge_from_file(cfg.MODEL.CLIP.OFFLINE_RPN_CONFIG)if cfg.MODEL.CLIP.OFFLINE_RPN_LSJ_PRETRAINED: # large-scale jittering (LSJ) pretrained RPNoffline_cfg.MODEL.BACKBONE.FREEZE_AT = 0 # make all fronzon layers to "SyncBN"offline_cfg.MODEL.RESNETS.NORM = "SyncBN" # 5 resnet layersoffline_cfg.MODEL.FPN.NORM = "SyncBN" # fpn layersoffline_cfg.MODEL.RPN.CONV_DIMS = [-1, -1] # rpn layersif cfg.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH:offline_cfg.MODEL.RPN.NMS_THRESH = cfg.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH  # 0.9if cfg.MODEL.CLIP.OFFLINE_RPN_POST_NMS_TOPK_TEST:offline_cfg.MODEL.RPN.POST_NMS_TOPK_TEST = cfg.MODEL.CLIP.OFFLINE_RPN_POST_NMS_TOPK_TEST # 1000# create offline backbone and RPNoffline_backbone = build_backbone(offline_cfg)offline_rpn = build_proposal_generator(offline_cfg, offline_backbone.output_shape())# convert to evaluation modefor p in offline_backbone.parameters(): p.requires_grad = Falsefor p in offline_rpn.parameters(): p.requires_grad = Falseoffline_backbone.eval()offline_rpn.eval()# region proposals are ground-truth boxeselif cfg.MODEL.CLIP.CROP_REGION_TYPE == "GT":offline_backbone = Noneoffline_rpn = Noneoffline_cfg = Nonebackbone = build_backbone(cfg)# build language encoderif cfg.MODEL.CLIP.GET_CONCEPT_EMB: # extract concept embeddingslanguage_encoder = build_clip_language_encoder(cfg)else:language_encoder = Noneroi_heads = build_roi_heads(cfg, backbone.output_shape())return {"offline_backbone": offline_backbone,"offline_proposal_generator": offline_rpn, "backbone": backbone,"language_encoder": language_encoder, "roi_heads": roi_heads, "input_format": cfg.INPUT.FORMAT,"vis_period": cfg.VIS_PERIOD,"pixel_mean": cfg.MODEL.PIXEL_MEAN,"pixel_std": cfg.MODEL.PIXEL_STD,"clip_crop_region_type" : cfg.MODEL.CLIP.CROP_REGION_TYPE,"use_clip_c4": cfg.MODEL.BACKBONE.NAME == "build_clip_resnet_backbone","use_clip_attpool": cfg.MODEL.ROI_HEADS.NAME in ['CLIPRes5ROIHeads', 'CLIPStandardROIHeads'] and cfg.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER,"offline_input_format": offline_cfg.INPUT.FORMAT if offline_cfg else None,"offline_pixel_mean": offline_cfg.MODEL.PIXEL_MEAN if offline_cfg else None,"offline_pixel_std": offline_cfg.MODEL.PIXEL_STD if offline_cfg else None,}

从上面可以看到 backbone ,language_encoder,roi_heads 构建相应的模块,基本上CLIPFastRCNN 的模块都在里面了。不过里面的 offline_backbone 让我疑惑,不知道这个是如何起作用的,发挥什么功能?我判断是加载离线模型 就是做过预训练的模型,用来生成proposals的,感觉这段代码不太好看,而且后面也不清晰怎么处理的。

还可以进一步看forward函数,直观了解数据处理:

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):"""Args:batched_inputs: a list, batched outputs of :class:`DatasetMapper` .Each item in the list contains the inputs for one image.For now, each item in the list is a dict that contains:* image: Tensor, image in (C, H, W) format.* instances (optional): groundtruth :class:`Instances`* proposals (optional): :class:`Instances`, precomputed proposals.Other information that's included in the original dicts, such as:* "height", "width" (int): the output resolution of the model, used in inference.See :meth:`postprocess` for details.Returns:list[dict]:Each dict is the output for one input image.The dict contains one key "instances" whose value is a :class:`Instances`.The :class:`Instances` object has the following keys:"pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints""""if not self.training:return self.inference(batched_inputs)if "instances" in batched_inputs[0]:gt_instances = [x["instances"].to(self.device) for x in batched_inputs]else:gt_instances = None# localization branch: offline modules to get the region proposalswith torch.no_grad():  if self.clip_crop_region_type == "GT":  # from ground-truthproposals = []for r_i, b_input in enumerate(batched_inputs): this_gt = copy.deepcopy(b_input["instances"])  # Instancegt_boxes = this_gt._fields['gt_boxes'].to(self.device)this_gt._fields = {'proposal_boxes': gt_boxes, 'objectness_logits': torch.ones(gt_boxes.tensor.size(0)).to(self.device)}proposals.append(this_gt)                elif self.clip_crop_region_type == "RPN": # from the backbone & RPN of standard Mask-RCNN, trained on base classesif self.offline_backbone.training or self.offline_proposal_generator.training:  #  was set to True in training scriptself.offline_backbone.eval() self.offline_proposal_generator.eval()  images = self.offline_preprocess_image(batched_inputs)features = self.offline_backbone(images.tensor)if self.offline_proposal_generator is not None:proposals, _ = self.offline_proposal_generator(images, features, None)     # recognition branch: get 2D feature maps using the backbone of recognition branchimages = self.preprocess_image(batched_inputs)features = self.backbone(images.tensor)# Given the proposals, crop region features from 2D image features and classify the regionsif self.use_clip_c4: # use C4 + resnet weights from CLIPif self.use_clip_attpool: # use att_pool from CLIP to match dimension_, detector_losses = self.roi_heads(images, features, proposals, gt_instances, res5=self.backbone.layer4, attnpool=self.backbone.attnpool)else: # use mean pool_, detector_losses = self.roi_heads(images, features, proposals, gt_instances, res5=self.backbone.layer4)else:  # regular detector settingif self.use_clip_attpool: # use att_pool from CLIP to match dimension_, detector_losses = self.roi_heads(images, features, proposals, gt_instances, attnpool=self.backbone.bottom_up.attnpool)else: # use mean pool_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)if self.vis_period > 0:storage = get_event_storage()if storage.iter % self.vis_period == 0:self.visualize_training(batched_inputs, proposals)#visualize_proposals(batched_inputs, proposals, self.input_format)losses = {}losses.update(detector_losses)return losses

可以看到数据输入 batched_inputs 的处理,features = self.backbone(images.tensor) 这一步完成特征提取,features里包含了文本特征,后续进入 roi_heads 进行损失计算。

以上就是RegionCLIP的 CLIPFastRCNN 的网络结构对应代码解析。

这篇关于RegionCLIP网络结构解析 Region-based Language-Image Pretraining的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/875525

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

OWASP十大安全漏洞解析

OWASP(开放式Web应用程序安全项目)发布的“十大安全漏洞”列表是Web应用程序安全领域的权威指南,它总结了Web应用程序中最常见、最危险的安全隐患。以下是对OWASP十大安全漏洞的详细解析: 1. 注入漏洞(Injection) 描述:攻击者通过在应用程序的输入数据中插入恶意代码,从而控制应用程序的行为。常见的注入类型包括SQL注入、OS命令注入、LDAP注入等。 影响:可能导致数据泄

从状态管理到性能优化:全面解析 Android Compose

文章目录 引言一、Android Compose基本概念1.1 什么是Android Compose?1.2 Compose的优势1.3 如何在项目中使用Compose 二、Compose中的状态管理2.1 状态管理的重要性2.2 Compose中的状态和数据流2.3 使用State和MutableState处理状态2.4 通过ViewModel进行状态管理 三、Compose中的列表和滚动

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

CSP 2023 提高级第一轮 CSP-S 2023初试题 完善程序第二题解析 未完

一、题目阅读 (最大值之和)给定整数序列 a0,⋯,an−1,求该序列所有非空连续子序列的最大值之和。上述参数满足 1≤n≤105 和 1≤ai≤108。 一个序列的非空连续子序列可以用两个下标 ll 和 rr(其中0≤l≤r<n0≤l≤r<n)表示,对应的序列为 al,al+1,⋯,ar​。两个非空连续子序列不同,当且仅当下标不同。 例如,当原序列为 [1,2,1,2] 时,要计算子序列 [

多线程解析报表

假如有这样一个需求,当我们需要解析一个Excel里多个sheet的数据时,可以考虑使用多线程,每个线程解析一个sheet里的数据,等到所有的sheet都解析完之后,程序需要提示解析完成。 Way1 join import java.time.LocalTime;public class Main {public static void main(String[] args) thro

ZooKeeper 中的 Curator 框架解析

Apache ZooKeeper 是一个为分布式应用提供一致性服务的软件。它提供了诸如配置管理、分布式同步、组服务等功能。在使用 ZooKeeper 时,Curator 是一个非常流行的客户端库,它简化了 ZooKeeper 的使用,提供了高级的抽象和丰富的工具。本文将详细介绍 Curator 框架,包括它的设计哲学、核心组件以及如何使用 Curator 来简化 ZooKeeper 的操作。 1