mask2former利用不确定性采样点选择提高模型性能

2024-06-13 04:04

本文主要是介绍mask2former利用不确定性采样点选择提高模型性能,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在机器学习和深度学习的训练过程中,不确定性高的点通常代表模型在这些点上的预测不够可靠或有较高的误差。因此,关注这些不确定性高的点,通过计算这些点的损失并进行梯度更新,可以有效地提高模型的整体性能。确定性高的点预测结果已经比较准确,相应地对模型的训练贡献较小,所以可以减少对这些点的关注或完全忽略它们的损失计算。

代码复现参考仓库:https://github.com/NielsRogge/Transformers-Tutorials

在这篇博客中,我们将详细解释 mask2former 中的一段代码,该代码通过不确定性采样点来选择重要点,并探讨其在模型训练中的重要性。mask2former原文描述比较简单,如下:
在这里插入图片描述

代码源自transformers库中的modeling_mask2former.py,主要讲解如下代码:

    def sample_points_using_uncertainty(self,logits: torch.Tensor,uncertainty_function,num_points: int,oversample_ratio: int,importance_sample_ratio: float,) -> torch.Tensor:"""This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. Theuncertainty is calculated for each point using the passed `uncertainty function` that takes points logitprediction as input.Args:logits (`float`):Logit predictions for P points.uncertainty_function:A function that takes logit predictions for P points and returns their uncertainties.num_points (`int`):The number of points P to sample.oversample_ratio (`int`):Oversampling parameter.importance_sample_ratio (`float`):Ratio of points that are sampled via importance sampling.Returns:point_coordinates (`torch.Tensor`):Coordinates for P sampled points."""num_boxes = logits.shape[0]num_points_sampled = int(num_points * oversample_ratio)# Get random point coordinatespoint_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)# Get sampled prediction value for the point coordinatespoint_logits = sample_point(logits, point_coordinates, align_corners=False)# Calculate the uncertainties based on the sampled prediction values of the pointspoint_uncertainties = uncertainty_function(point_logits)#[n1+n2, 1, 37632],理解为,值越大,不确定性越高num_uncertain_points = int(importance_sample_ratio * num_points)#9408num_random_points = num_points - num_uncertain_points#3136idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]#[n1+n2, 9408]这行代码的作用是从每个 num_boxes 的不确定性值中选择 num_uncertain_points 个最大值的索引。这些索引将用于从原始的点坐标张量 point_coordinates 中选择相应的点,这些点将被认为是基于不确定性的重要性采样点。shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)#这两行代码的主要目的是确保在从 point_coordinates 中选择点时,能够正确地访问全局索引,使得每个 box 的采样点能够准确地映射到整个张量中的位置。idx += shift[:, None]point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)#[n1+n2, 9408, 2]if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)return point_coordinates

以下是 sample_points_using_uncertainty 函数的参数解释:

  • logits (torch.Tensor): P 个点的 logit 预测值。
  • uncertainty_function: 一个函数,接受 P 个点的 logit 预测值并返回它们的不确定性。
  • num_points (int): 需要采样的点的数量 P。
  • oversample_ratio (int): 过采样参数,用于增加采样点的数量,以确保能在不确定性采样中选到合适的点。
  • importance_sample_ratio (float): 使用重要性采样选出的点的比例。

函数步骤解释

  1. 计算总采样点数

    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)
    

    num_boxes 是指预测的盒子数量,num_points_sampled 是经过过采样之后的总采样点数。

  2. 生成随机点的坐标

    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    

    在 [0, 1] * [0, 1] 空间内生成随机点的坐标。

  3. 获取这些随机点的预测值

    point_logits = sample_point(logits, point_coordinates, align_corners=False)
    

    对随机点的坐标进行采样,获取它们的预测 logit 值。

  4. 计算这些点的不确定性

    point_uncertainties = uncertainty_function(point_logits)
    

    使用 uncertainty_function 计算这些点的不确定性。

  5. 确定不确定性采样和随机采样的点数

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    

    根据 importance_sample_ratio 确定通过不确定性采样的点数 num_uncertain_points,以及剩余的随机采样点数 num_random_points

  6. 选择不确定性最高的点

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    

    使用 torch.topk 函数选择每个盒子中不确定性最高的 num_uncertain_points 个点,并获取它们的坐标。

  7. 添加随机点

    if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)
    

    如果需要添加随机采样点,将它们与不确定性采样点合并。

  8. 返回采样点的坐标

    return point_coordinates
    

    最终返回所有采样点的坐标。

关键代码解读

1. 偏移量的生成
 shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)

这行代码的目的是为每个 box 生成一个偏移量(shift),用于转换局部索引为全局索引。

  • torch.arange(num_boxes, dtype=torch.long, device=logits.device) 生成一个从 0 到 num_boxes-1 的张量。
  • num_points_sampled 是每个 box 中采样的点的数量。
  • 乘法操作 num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 为每个 box 生成一个偏移量。例如,假设 num_points_sampled 为 100,那么生成的偏移量张量为 [0, 100, 200, 300, ...]

这些偏移量将用于将局部索引(即每个 box 内的索引)转换为全局索引(即在整个 point_coordinates 中的索引)。

2. 局部索引转换为全局索引
  idx += shift[:, None]

这行代码将局部索引转换为全局索引。

  • idxtorch.topk 返回的不确定性最高的点的局部索引,形状为 [num_boxes, num_uncertain_points]
  • shift[:, None] 的形状是 [num_boxes, 1],通过这种方式将每个 box 的偏移量广播到与 idx 的形状匹配。

通过将 shift 加到 idx 上,每个 box 的局部索引将变成全局索引。例如,如果第一个 box 的偏移量为 100,那么第一个 box 内的局部索引 [0, 1, 2, ...] 将变为 [100, 101, 102, ...]

总结

通过 sample_points_using_uncertainty 函数,我们可以有效地选择不确定性高的点进行训练,提高模型在这些关键点上的表现,同时减少确定性高的点的计算开销。这种不确定性采样方法结合了重要性采样和随机采样,确保了模型训练的高效性和鲁棒性。

这篇关于mask2former利用不确定性采样点选择提高模型性能的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056225

相关文章

Vue3 的 shallowRef 和 shallowReactive:优化性能

大家对 Vue3 的 ref 和 reactive 都很熟悉,那么对 shallowRef 和 shallowReactive 是否了解呢? 在编程和数据结构中,“shallow”(浅层)通常指对数据结构的最外层进行操作,而不递归地处理其内部或嵌套的数据。这种处理方式关注的是数据结构的第一层属性或元素,而忽略更深层次的嵌套内容。 1. 浅层与深层的对比 1.1 浅层(Shallow) 定义

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

如何选择适合孤独症兄妹的学校?

在探索适合孤独症儿童教育的道路上,每一位家长都面临着前所未有的挑战与抉择。当这份责任落在拥有孤独症兄妹的家庭肩上时,选择一所能够同时满足两个孩子特殊需求的学校,更显得尤为关键。本文将探讨如何为这样的家庭做出明智的选择,并介绍星贝育园自闭症儿童寄宿制学校作为一个值得考虑的选项。 理解孤独症儿童的独特性 孤独症,这一复杂的神经发育障碍,影响着儿童的社交互动、沟通能力以及行为模式。对于拥有孤独症兄

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

黑神话,XSKY 星飞全闪单卷性能突破310万

当下,云计算仍然是企业主要的基础架构,随着关键业务的逐步虚拟化和云化,对于块存储的性能要求也日益提高。企业对于低延迟、高稳定性的存储解决方案的需求日益迫切。为了满足这些日益增长的 IO 密集型应用场景,众多云服务提供商正在不断推陈出新,推出具有更低时延和更高 IOPS 性能的云硬盘产品。 8 月 22 日 2024 DTCC 大会上(第十五届中国数据库技术大会),XSKY星辰天合正式公布了基于星

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验