mask2former利用不确定性采样点选择提高模型性能

2024-06-13 04:04

本文主要是介绍mask2former利用不确定性采样点选择提高模型性能,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在机器学习和深度学习的训练过程中,不确定性高的点通常代表模型在这些点上的预测不够可靠或有较高的误差。因此,关注这些不确定性高的点,通过计算这些点的损失并进行梯度更新,可以有效地提高模型的整体性能。确定性高的点预测结果已经比较准确,相应地对模型的训练贡献较小,所以可以减少对这些点的关注或完全忽略它们的损失计算。

代码复现参考仓库:https://github.com/NielsRogge/Transformers-Tutorials

在这篇博客中,我们将详细解释 mask2former 中的一段代码,该代码通过不确定性采样点来选择重要点,并探讨其在模型训练中的重要性。mask2former原文描述比较简单,如下:
在这里插入图片描述

代码源自transformers库中的modeling_mask2former.py,主要讲解如下代码:

    def sample_points_using_uncertainty(self,logits: torch.Tensor,uncertainty_function,num_points: int,oversample_ratio: int,importance_sample_ratio: float,) -> torch.Tensor:"""This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. Theuncertainty is calculated for each point using the passed `uncertainty function` that takes points logitprediction as input.Args:logits (`float`):Logit predictions for P points.uncertainty_function:A function that takes logit predictions for P points and returns their uncertainties.num_points (`int`):The number of points P to sample.oversample_ratio (`int`):Oversampling parameter.importance_sample_ratio (`float`):Ratio of points that are sampled via importance sampling.Returns:point_coordinates (`torch.Tensor`):Coordinates for P sampled points."""num_boxes = logits.shape[0]num_points_sampled = int(num_points * oversample_ratio)# Get random point coordinatespoint_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)# Get sampled prediction value for the point coordinatespoint_logits = sample_point(logits, point_coordinates, align_corners=False)# Calculate the uncertainties based on the sampled prediction values of the pointspoint_uncertainties = uncertainty_function(point_logits)#[n1+n2, 1, 37632],理解为,值越大,不确定性越高num_uncertain_points = int(importance_sample_ratio * num_points)#9408num_random_points = num_points - num_uncertain_points#3136idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]#[n1+n2, 9408]这行代码的作用是从每个 num_boxes 的不确定性值中选择 num_uncertain_points 个最大值的索引。这些索引将用于从原始的点坐标张量 point_coordinates 中选择相应的点,这些点将被认为是基于不确定性的重要性采样点。shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)#这两行代码的主要目的是确保在从 point_coordinates 中选择点时,能够正确地访问全局索引,使得每个 box 的采样点能够准确地映射到整个张量中的位置。idx += shift[:, None]point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)#[n1+n2, 9408, 2]if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)return point_coordinates

以下是 sample_points_using_uncertainty 函数的参数解释:

  • logits (torch.Tensor): P 个点的 logit 预测值。
  • uncertainty_function: 一个函数,接受 P 个点的 logit 预测值并返回它们的不确定性。
  • num_points (int): 需要采样的点的数量 P。
  • oversample_ratio (int): 过采样参数,用于增加采样点的数量,以确保能在不确定性采样中选到合适的点。
  • importance_sample_ratio (float): 使用重要性采样选出的点的比例。

函数步骤解释

  1. 计算总采样点数

    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)
    

    num_boxes 是指预测的盒子数量,num_points_sampled 是经过过采样之后的总采样点数。

  2. 生成随机点的坐标

    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    

    在 [0, 1] * [0, 1] 空间内生成随机点的坐标。

  3. 获取这些随机点的预测值

    point_logits = sample_point(logits, point_coordinates, align_corners=False)
    

    对随机点的坐标进行采样,获取它们的预测 logit 值。

  4. 计算这些点的不确定性

    point_uncertainties = uncertainty_function(point_logits)
    

    使用 uncertainty_function 计算这些点的不确定性。

  5. 确定不确定性采样和随机采样的点数

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    

    根据 importance_sample_ratio 确定通过不确定性采样的点数 num_uncertain_points,以及剩余的随机采样点数 num_random_points

  6. 选择不确定性最高的点

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    

    使用 torch.topk 函数选择每个盒子中不确定性最高的 num_uncertain_points 个点,并获取它们的坐标。

  7. 添加随机点

    if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)
    

    如果需要添加随机采样点,将它们与不确定性采样点合并。

  8. 返回采样点的坐标

    return point_coordinates
    

    最终返回所有采样点的坐标。

关键代码解读

1. 偏移量的生成
 shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)

这行代码的目的是为每个 box 生成一个偏移量(shift),用于转换局部索引为全局索引。

  • torch.arange(num_boxes, dtype=torch.long, device=logits.device) 生成一个从 0 到 num_boxes-1 的张量。
  • num_points_sampled 是每个 box 中采样的点的数量。
  • 乘法操作 num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 为每个 box 生成一个偏移量。例如,假设 num_points_sampled 为 100,那么生成的偏移量张量为 [0, 100, 200, 300, ...]

这些偏移量将用于将局部索引(即每个 box 内的索引)转换为全局索引(即在整个 point_coordinates 中的索引)。

2. 局部索引转换为全局索引
  idx += shift[:, None]

这行代码将局部索引转换为全局索引。

  • idxtorch.topk 返回的不确定性最高的点的局部索引,形状为 [num_boxes, num_uncertain_points]
  • shift[:, None] 的形状是 [num_boxes, 1],通过这种方式将每个 box 的偏移量广播到与 idx 的形状匹配。

通过将 shift 加到 idx 上,每个 box 的局部索引将变成全局索引。例如,如果第一个 box 的偏移量为 100,那么第一个 box 内的局部索引 [0, 1, 2, ...] 将变为 [100, 101, 102, ...]

总结

通过 sample_points_using_uncertainty 函数,我们可以有效地选择不确定性高的点进行训练,提高模型在这些关键点上的表现,同时减少确定性高的点的计算开销。这种不确定性采样方法结合了重要性采样和随机采样,确保了模型训练的高效性和鲁棒性。

这篇关于mask2former利用不确定性采样点选择提高模型性能的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056225

相关文章

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、

Golang中拼接字符串的6种方式性能对比

《Golang中拼接字符串的6种方式性能对比》golang的string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去,主要有6种拼接方式,下面小编就来为大家详细讲讲吧... 目录拼接方式介绍性能对比测试代码测试结果源码分析golang的string类型是不可修改的,对于拼接字

基于Python实现多语言朗读与单词选择测验

《基于Python实现多语言朗读与单词选择测验》在数字化教育日益普及的今天,开发一款能够支持多语言朗读和单词选择测验的程序,对于语言学习者来说无疑是一个巨大的福音,下面我们就来用Python实现一个这... 目录一、项目概述二、环境准备三、实现朗读功能四、实现单词选择测验五、创建图形用户界面六、运行程序七、

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

mysql线上查询之前要性能调优的技巧及示例

《mysql线上查询之前要性能调优的技巧及示例》文章介绍了查询优化的几种方法,包括使用索引、避免不必要的列和行、有效的JOIN策略、子查询和派生表的优化、查询提示和优化器提示等,这些方法可以帮助提高数... 目录避免不必要的列和行使用有效的JOIN策略使用子查询和派生表时要小心使用查询提示和优化器提示其他常

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot