ML-Decoder: Scalable and Versatile Classification Head

2024-03-31 05:12

本文主要是介绍ML-Decoder: Scalable and Versatile Classification Head,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、引言

论文链接:https://openaccess.thecvf.com/content/WACV2023/papers/Ridnik_ML-Decoder_Scalable_and_Versatile_Classification_Head_WACV_2023_paper.pdf

        因为 transformer 解码器分类头[1] 在少类别多标签分类数据集上表现得很好,但由于其查询复杂度为 O(n^2),n 为类别数量,故 transformer 解码器分类头对于多类别数据集是不可行的,且 transformer 解码器分类头只适用于多标签分类任务,故 Tal Ridnik 等引入了一种新的基于多头注意力机制的分类头——ML-Decoder[2]。ML-Decoder 可以用于单标签分类、多标签分类和多标签 ZSL(zero shot learning) 任务,它提供更好的精度-速度 trade-off,可以用于上万类别的数据集,可以作为各种分类头的 drop-in 替代品,结合词查询可以用于 ZSL。

2、方法

        ML-Decoder 流如图 1 右所示,相对于  transformer 解码器分类头,ML-Decoder 有一下改变。

图1  transformer-decoder vs. ML-Decoder

2.1  移除自注意力机制

        通过删除自注意力机制将 ML-Decoder 的查询复杂度由 O(n^2) 降至 O(n),并未影响表示能力。

2.2  组解码

        为了使查询数量与类别数量无关,使用固定的 k 组查询,而不是一个类别对应一个查询。在前馈神经网络后,通过组全连接层在将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。如图 2 所示。

图2  组全连接方案(g=4)

2.3  固定查询        

        查询总是被输入到一个多头注意力层,该注意力层会先对查询应用一个可学习的投影计算。因此,将查询权重设置为可学习的是多余的——可学习的投影可以将任何固定值查询转换为可学习查询获得的任何值。

3、模块介绍

3.1  Cross-Attention

        Cross-Attention 的核心其实就是多头注意力机制,输入的 q 为固定查询,k 和 v 均为图像嵌入。Cross-Attention 和 Feed-Forward 模块构成所谓的 TransformerDecoder(Layer),python 代码如下所示:

class TransformerDecoder(nn.Module):def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1) -> None:super().__init__()self.dropout = nn.Dropout(dropout)self.norm0 = nn.LayerNorm(d_model)self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)# Implementation of Feedforward modelself.feed_forward = nn.Sequential(nn.LayerNorm(d_model),nn.Linear(d_model, dim_feedforward),nn.ReLU(),nn.Dropout(dropout),nn.Linear(dim_feedforward, d_model))self.norm1 = nn.LayerNorm(d_model)def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:tgt = tgt + self.dropout(tgt)tgt = self.norm0(tgt)tgt0 = self.multihead_attn(tgt, memory, memory)[0]tgt = tgt + self.dropout(tgt0)tgt0 = self.feed_forward(tgt)tgt = tgt + self.dropout(tgt0)return self.norm1(tgt)

3.2  Group Fully Connected Pooling  

        Group Fully Connected Pooling的目的是将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。即将每组查询结果与对应的可学习的 (hidde_dim, g) 维矩阵相乘,python 代码如下所示:

class GroupFC(object):def __init__(self, groups: int):self.groups = groupsdef __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):"""计算每组类的 logits 值(未加偏置):param h: shape=(b, groups, hidden_dim):param duplicate_pooling: shape=(groups, hidden_dim, duplicate_factor), duplicate_factor 每组的类别数:param out_extrap: shape=(b, groups, duplicate_factor):return:"""for i in range(h.shape[1]):h_i = h[:, i, :]w_i = duplicate_pooling[i, :, :]out_extrap[:, i, :] = torch.matmul(h_i, w_i)

4、总结

        作者开源的 ML-Decoder 的 python 实现代码在:https://github.com/Alibaba-MIIL/ML_Decoder/blob/main/src_files/ml_decoder/ml_decoder.py

        论文[2] 在 paper with code 上的战绩如图 3 所示,表现还是不错的。

图3  来自论文[2] 的结果

        由于当参数 zsl != 0 时 wordvec_proj 的输入 query_embed = None,本人还未学习过 ZSL 领域,且使用该代码时报错(zsl = 0,当然应该是我的原因,但懒得排错了),于是参考作者的代码写了一个 MLDecoder 类(只考虑 zsl = 0),剩下的代码如下所示。

class MLDecoder(nn.Module):"""Args:groups: 查询/类别组数hidden_dim: Transformer 解码器特征维度in_dim: 输入 tensor 特征维度(CNN 编码器输出为通道数,Transformer 编码器输出为最后一个维度)"""def __init__(self, num_classes, groups, in_dim=2048, hidden_dim=768, mlp_dim=2048, nhead=8, dropout=0.1):super().__init__()self.proj = nn.Linear(in_dim, hidden_dim)# non-learnable queriesself.query_embed = nn.Embedding(groups, hidden_dim)self.query_embed.requires_grad_(False)self.num_classes = num_classesself.decoder = TransformerDecoder(d_model=hidden_dim, nhead=nhead, dim_feedforward=mlp_dim, dropout=dropout)# group fully-connectedself.duplicate_factor = math.ceil(num_classes / groups)  # 每组类别数量,math.ceil: 向上取整self.duplicate_pooling = torch.nn.Parameter(torch.zeros((groups, hidden_dim, self.duplicate_factor)))self.duplicate_pooling_bias = torch.nn.Parameter(torch.zeros(num_classes))torch.nn.init.xavier_normal_(self.duplicate_pooling)self.group_fc = GroupFC(groups)def forward(self, x):# 确保解码器输入 shape 为 [b, h * w, c]if len(x.shape) == 4:x = x.flatten(2).transpose(1, 2)x = F.relu(self.proj(x), True)  # (b, h * w, hidden_dim)# Cross-Attention + Feed-Forwardquery_embed = self.query_embed.weight  # (groups, hidden_dim)# tensor.expend: 增大一个维度至指定大小, 不增大的维度为-1,例如将 shape 由 (b, n, c)->(b, 2n, c), 参数 size=(-1, 2n,-1)tgt = query_embed[None].expand(x.shape[0], -1, -1)  # (b, groups, hidden_dim)h = self.decoder(tgt, x)  # (b, groups, hidden_dim)# Group Fully Connected Poolingout_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)self.group_fc(h, self.duplicate_pooling, out_extrap)h_out = out_extrap.flatten(1)[:, :self.num_classes]  # (b, num_classes)return h_out + self.duplicate_pooling_bias

参考文献

[1] Shilong Liu, Lei Zhang, Xiao Yang, Hang Su, and Jun Zhu. Query2label: A simple transformer way to multi-label classification. arXiv preprint arXiv:2107.10834, 2021.

[2] Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben Baruch, and Asaf Noy. Ml-decoder: Scalable and versatile classification head. In IEEE/CVF Winter Conference on Applications of Computer Vision, WACV 2023, Waikoloa, HI, USA, January 2-7, 2023, pages 32–41. IEEE, 2023.

这篇关于ML-Decoder: Scalable and Versatile Classification Head的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/863472

相关文章

跟我一起玩《linux内核设计的艺术》第1章(四)——from setup.s to head.s,这回一定让main滚出来!(已解封)

看到书上1.3的大标题,以为马上就要见着main了,其实啊,还早着呢,光看setup.s和head.s的代码量就知道,跟bootsect.s没有可比性,真多……这确实需要包括我在内的大家多一些耐心,相信见着main后,大家的信心和干劲会上一个台阶,加油! 既然上篇已经玩转gdb,接下来的讲解肯定是边调试边分析书上的内容,纯理论讲解其实我并不在行。 setup.s: 目标:争取把setup.

ElasticSearch 6.1.1 通过Head插件,新建索引,添加文档,及其查询数据

ElasticSearch 6.1.1 通过Head插件,新建索引,添加文档,及其查询; 一、首先启动相关服务: 二、新建一个film索引: 三、建立映射: 1、通过Head插件: POST http://192.168.1.111:9200/film/_mapping/dongzuo/ {"properties": {"title": {"type":

Windows环境下ElasticSearch6.1.1版本安装Head插件

安装Head插件步骤如下: 1、下载node.js ,网址:https://nodejs.org/en/ 安装node到D盘。如D:\nodejs。 把NODE_HOME设置到环境变量里(安装包也可以自动加入PATH环境变量)。测试一下node是否生效: 2、安装grunt grunt是一个很方便的构建工具,可以进行打包压缩、测试、执行等等的工作,5.0里的head插件就是通过grunt

Linux中head和tail方法的使用

head -5 1.txt 从第五行开始到末尾 head –n 5 1.txt 同上 head –n +5 1.txt 同上 head –n -5 1.txt 除了最后五行的所有内容 少后五行   tail -5 1.txt 最后五行内容 tail –n 5 1.txt 同上 tail –n -5 1.txt 同上 tail –n +5 1.txt 从正数第五行到结尾的所有内容

【ML--05】第五课 如何做特征工程和特征选择

一、如何做特征工程? 1.排序特征:基于7W原始数据,对数值特征排序,得到1045维排序特征 2. 离散特征:将排序特征区间化(等值区间化、等量区间化),比如采用等量区间化为1-10,得到1045维离散特征 3. 计数特征:统计每一行中,离散特征1-10的个数,得到10维计数特征 4. 类别特征编码:将93维类别特征用one-hot编码 5. 交叉特征:特征之间两两融合,x+y、x-y、

【ML--04】第四课 logistic回归

1、什么是逻辑回归? 当要预测的y值不是连续的实数(连续变量),而是定性变量(离散变量),例如某个客户是否购买某件商品,这时线性回归模型不能直接作用,我们就需要用到logistic模型。 逻辑回归是一种分类的算法,它用给定的输入变量(X)来预测二元的结果(Y)(1/0,是/不是,真/假)。我们一般用虚拟变量来表示二元/类别结果。你可以把逻辑回归看成一种特殊的线性回归,只是因为最后的结果是类别变

【ML--13】聚类--层次聚类

一、基本概念 层次聚类不需要指定聚类的数目,首先它是将数据中的每个实例看作一个类,然后将最相似的两个类合并,该过程迭代计算只到剩下一个类为止,类由两个子类构成,每个子类又由更小的两个子类构成。 层次聚类方法对给定的数据集进行层次的分解,直到某种条件满足或者达到最大迭代次数。具体又可分为: 凝聚的层次聚类(AGNES算法):一种自底向上的策略,首先将每个对象作为一个簇,然后合并这些原子簇为越来

Convolutional Neural Networks for Sentence Classification论文解读

基本信息 作者Yoon Kimdoi发表时间2014期刊EMNLP网址https://doi.org/10.48550/arXiv.1408.5882 研究背景 1. What’s known 既往研究已证实 CV领域著名的CNN。 2. What’s new 创新点 将CNN应用于NLP,打破了传统NLP任务主要依赖循环神经网络(RNN)及其变体的局面。 用预训练的词向量(如word2v

less、more、head、tail命令解析集合

一、整体认识 命令使用优点常见使用方式less可以浏览文件内容,它可以用于查看大型文件,而不需要将整个文件加载到内存中。按下空格键向下翻页,按下b键向上翻页,按下q键退出浏览more类似于less,也是用于浏览文件内容的命令,但它不支持向上翻页。 按下空格键向下翻页,按下q键退出浏览。 head用于查看文件的前几行。head [选项] [文件名]。常用选项有-n,指定显示前几行,默认为显示前10

《Head First设计模式》之命令模式

命令模式就是将方法调用(Method invocation)封装起来。通过封装方法调用,我们可以把运算块包装成形,所以调用此运算的对象不需要关心事情是如何进行的,只要知道如何使用包装成形的方法来完成它就可以了。通过封装方法调用,可以用在以下场景:记录日志或者重复使用这些封装来实现撤销(undo)。     我对于命令模式的理解是:当我需要做一件事的时候,我只需要给出一个命令,这个命令中的