从DETR到Mask2Former(3):masked attention的attention map可视化

2024-01-13 20:36

本文主要是介绍从DETR到Mask2Former(3):masked attention的attention map可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Mask2Former的论文中有这样一张图,表示masked attenion比cross attention效果要好

那么这个attention map是怎么画出来的?

在mask2attention的源代码中 CrossAttentionLayer这个类中,在forward_post函数中做如下修改:

    def forward_post(self, tgt, memory,memory_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):tgt2, atten_weight = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask, average_attn_weights=False)atten_weight = atten_weight.squeeze().detach().cpu().numpy()head_num = 0selected_query_num = 0if atten_weight.shape[-1] == 21888:import matplotlib.pyplot as plt# 创建2行4列的图形fig, axs = plt.subplots(2, 4, figsize=(12, 6))# 使用8次for循环在每个子图中进行绘制for i in range(2):for j in range(4):atten_map = atten_weight[head_num, selected_query_num, :]atten_map = atten_map.reshape((128, 171))head_num += 0axs[i, j].imshow(atten_map)plt.show()tgt = tgt + self.dropout(tgt2)tgt = self.norm(tgt)return tgt

在 nn.MultiheadAttention 类实例的forward方法中,加入

average_attn_weights=False

得到每个注意力头的attention map,将attention_weight可视化,就得到了论文中的图片。

这篇关于从DETR到Mask2Former(3):masked attention的attention map可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/602703

相关文章

什么是 Flash Attention

Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的, 论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。 下面我

Collection List Set Map的区别和联系

Collection List Set Map的区别和联系 这些都代表了Java中的集合,这里主要从其元素是否有序,是否可重复来进行区别记忆,以便恰当地使用,当然还存在同步方面的差异,见上一篇相关文章。 有序否 允许元素重复否 Collection 否 是 List 是 是 Set AbstractSet 否

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目

基于SSM+Vue+MySQL的可视化高校公寓管理系统

系统展示 管理员界面 宿管界面 学生界面 系统背景   当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化,规范化管理。这样的大环境让那些止步不前,不接受信息改革带来的信息技术的企业随时面临被淘汰,被取代的风险。所以当今,各个行业领域,不管是传统的教育行业

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

「大数据分析」图形可视化,如何选择大数据可视化图形?

​图形可视化技术,在大数据分析中,是一个非常重要的关键部分。我们前期通过数据获取,数据处理,数据分析,得出结果,这些过程都是比较抽象的。如果是非数据分析专业人员,很难清楚我们这些工作,到底做了些什么事情。即使是专业人员,在不清楚项目,不了解业务规则,不熟悉技术细节的情况下。要搞清楚我们的大数据分析,这一系列过程,也是比较困难的。 我们在数据处理和分析完成后,一般来说,都需要形成结论报告。怎样让大

11Python的Pandas:可视化

Pandas本身并没有直接的可视化功能,但它与其他Python库(如Matplotlib和Seaborn)无缝集成,允许你快速创建各种图表和可视化。这里是一些使用Pandas数据进行可视化的常见方法: 1. 使用Matplotlib Pandas中的plot()方法实际上是基于Matplotlib的,你可以使用它来绘制各种基本图表,例如折线图、柱状图、散点图等。 import pandas

【全网最全】2024年数学建模国赛A题30页完整建模文档+17页成品论文+保奖matla代码+可视化图表等(后续会更新)

您的点赞收藏是我继续更新的最大动力! 一定要点击如下的卡片,那是获取资料的入口! 【全网最全】2024年数学建模国赛A题30页完整建模文档+17页成品论文+保奖matla代码+可视化图表等(后续会更新)「首先来看看目前已有的资料,还会不断更新哦~一次购买,后续不会再被收费哦,保证是全网最全资源,随着后续内容更新,价格会上涨,越早购买,价格越低,让大家再也不需要到处买断片资料啦~💰💸👋」�

Map

Map 是 Java 中用于存储键值对的集合接口。以下是对 Map 的详细介绍: 特点 键值对存储:每个元素包含一个键和一个值。 键唯一:键不能重复,但值可以重复。 无序/有序:根据具体实现,键值对的顺序可能无序(如 HashMap)或有序(如 TreeMap、LinkedHashMap)。 主要实现类 HashMap 基于哈希表,无序存储。 允许一个 null 键和多个 null 值。

Java中集合类Set、List和Map的区别

Java中的集合包括三大类,它们是Set、List和Map,它们都处于java.util包中,Set、List和Map都是接口,它们有各自的实现类。Set的实现类主要有HashSet和TreeSet,List的实现类主要有ArrayList,Map的实现类主要有HashMap和TreeMap。那么它们有什么区别呢? Set中的对象不按特定方式排序,并且没有重复对象。但它的有些实现类能对集合中的对