Co Attention注意力机制实现

2024-04-24 20:32

本文主要是介绍Co Attention注意力机制实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

“Hierarchical Question-Image Co-Attention for Visual Question Answering”中的图像和文本间的Co Attention协同注意力实现

参考:

https://github.com/SkyOL5/VQA-CoAttention/blob/master/coatt/coattention_net.py

https://github.com/Zhangtd/Models-reproducing/blob/master/NIPS2016/selfDef.py

Co Attention示意图如下:

有两种实现方式,分别为Parallel co-attention mechanism和Alternating co-attention mechanism

基于PyTorch实现Parallel co-attention mechanism,代码如下:

from typing import Dict, Optionalimport numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import Tensordef create_src_lengths_mask(batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):"""Generate boolean mask to prevent attention beyond the end of sourceInputs:batch_size : intsrc_lengths : [batch_size] of sentence lengthsmax_src_len: Optionally override max_src_len for the maskOutputs:[batch_size, max_src_len]"""if max_src_len is None:max_src_len = int(src_lengths.max())src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)src_indices = src_indices.expand(batch_size, max_src_len)src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)# returns [batch_size, max_seq_len]return (src_indices < src_lengths).int().detach()def masked_softmax(scores, src_lengths, src_length_masking=True):"""Apply source length masking then softmax.Input and output have shape bsz x src_len"""if src_length_masking:bsz, max_src_len = scores.size()# print('bsz:', bsz)# compute maskssrc_mask = create_src_lengths_mask(bsz, src_lengths)# Fill pad positions with -infscores = scores.masked_fill(src_mask == 0, -np.inf)# Cast to float and then back again to prevent loss explosion under fp16.return F.softmax(scores.float(), dim=-1).type_as(scores)class ParallelCoAttentionNetwork(nn.Module):def __init__(self, hidden_dim, co_attention_dim, src_length_masking=True):super(ParallelCoAttentionNetwork, self).__init__()self.hidden_dim = hidden_dimself.co_attention_dim = co_attention_dimself.src_length_masking = src_length_maskingself.W_b = nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim))self.W_v = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))self.W_q = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))self.w_hv = nn.Parameter(torch.randn(self.co_attention_dim, 1))self.w_hq = nn.Parameter(torch.randn(self.co_attention_dim, 1))def forward(self, V, Q, Q_lengths):""":param V: batch_size * hidden_dim * region_num, eg B x 512 x 196:param Q: batch_size * seq_len * hidden_dim, eg B x L x 512:param Q_lengths: batch_size:return:batch_size * 1 * region_num, batch_size * 1 * seq_len,batch_size * hidden_dim, batch_size * hidden_dim"""# (batch_size, seq_len, region_num)C = torch.matmul(Q, torch.matmul(self.W_b, V))# (batch_size, co_attention_dim, region_num)H_v = nn.Tanh()(torch.matmul(self.W_v, V) + torch.matmul(torch.matmul(self.W_q, Q.permute(0, 2, 1)), C))# (batch_size, co_attention_dim, seq_len)H_q = nn.Tanh()(torch.matmul(self.W_q, Q.permute(0, 2, 1)) + torch.matmul(torch.matmul(self.W_v, V), C.permute(0, 2, 1)))# (batch_size, 1, region_num)a_v = F.softmax(torch.matmul(torch.t(self.w_hv), H_v), dim=2)# (batch_size, 1, seq_len)a_q = F.softmax(torch.matmul(torch.t(self.w_hq), H_q), dim=2)# # (batch_size, 1, seq_len)masked_a_q = masked_softmax(a_q.squeeze(1), Q_lengths, self.src_length_masking).unsqueeze(1)# (batch_size, hidden_dim)v = torch.squeeze(torch.matmul(a_v, V.permute(0, 2, 1)))# (batch_size, hidden_dim)q = torch.squeeze(torch.matmul(masked_a_q, Q))return a_v, masked_a_q, v, q

测试代码如下:

pcan = ParallelCoAttentionNetwork(6, 5)
v = torch.randn((5, 6, 10))
q = torch.randn(5, 8, 6)
q_lens = torch.LongTensor([3, 4, 5, 8, 2])
a_v, a_q, v, q = pcan(v, q, q_lens)
print(a_v)
print(a_v.shape)
print(a_q)
print(a_q.shape)
print(v)
print(v.shape)
print(q)
print(q.shape)

效果如下:

tensor([[[9.2527e-04, 1.1542e-03, 1.1542e-03, 1.1542e-03, 2.0009e-02,9.2527e-04, 4.0845e-02, 8.8328e-01, 1.1958e-03, 4.9358e-02]],[[4.5501e-01, 8.6522e-02, 8.6522e-02, 1.7235e-05, 3.8831e-03,2.5070e-04, 9.0637e-05, 4.0010e-03, 2.0196e-03, 3.6169e-01]],[[8.8455e-03, 7.2149e-04, 1.7595e-04, 2.1307e-04, 7.0610e-01,1.3427e-01, 4.3360e-04, 4.0731e-02, 4.0731e-02, 6.7774e-02]],[[4.0013e-01, 2.3081e-02, 3.8406e-02, 4.3583e-03, 9.9425e-05,3.8398e-02, 9.9425e-05, 9.4912e-02, 4.0013e-01, 3.9162e-04]],[[3.1121e-02, 8.0567e-05, 4.0445e-01, 1.4391e-03, 8.0567e-05,4.0445e-01, 7.6909e-02, 2.4837e-04, 4.3044e-03, 7.6909e-02]]],grad_fn=<SoftmaxBackward>)
torch.Size([5, 1, 10])
tensor([[[0.3466, 0.3267, 0.3267, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],[[0.2256, 0.3276, 0.2237, 0.2232, 0.0000, 0.0000, 0.0000, 0.0000]],[[0.1761, 0.2254, 0.2254, 0.1823, 0.1908, 0.0000, 0.0000, 0.0000]],[[0.1292, 0.1411, 0.1411, 0.1100, 0.1292, 0.1100, 0.1101, 0.1292]],[[0.5284, 0.4716, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],grad_fn=<UnsqueezeBackward0>)
torch.Size([5, 1, 8])
tensor([[-0.7862,  1.0180,  0.1585,  0.4961, -1.5916, -0.3553],[ 0.3624, -0.2036,  0.2993, -0.4440,  0.2494,  1.4896],[ 0.1695, -0.2286,  0.4431,  0.6027, -1.6116,  0.0566],[ 0.2004,  0.8219, -0.2115, -0.6428,  0.3486,  1.3802],[ 1.4024, -0.1860,  0.1685,  0.2352, -0.4956,  1.0010]],grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])
tensor([[ 0.3757,  0.1662,  0.2181,  0.0787,  0.0110, -0.5938],[-0.6106,  0.4000,  0.6068, -0.4054,  0.0193, -0.1147],[ 0.3877, -0.1800,  1.2430, -0.4881, -0.3598, -0.3592],[-0.3799, -0.3262,  0.0745, -0.2856,  0.0221, -0.1749],[ 0.1159, -0.4949, -0.5576, -0.6870, -1.2895,  0.0421]],grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])

这篇关于Co Attention注意力机制实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/932765

相关文章

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

Android实现在线预览office文档的示例详解

《Android实现在线预览office文档的示例详解》在移动端展示在线Office文档(如Word、Excel、PPT)是一项常见需求,这篇文章为大家重点介绍了两种方案的实现方法,希望对大家有一定的... 目录一、项目概述二、相关技术知识三、实现思路3.1 方案一:WebView + Office Onl

C# foreach 循环中获取索引的实现方式

《C#foreach循环中获取索引的实现方式》:本文主要介绍C#foreach循环中获取索引的实现方式,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、手动维护索引变量二、LINQ Select + 元组解构三、扩展方法封装索引四、使用 for 循环替代

Spring Security+JWT如何实现前后端分离权限控制

《SpringSecurity+JWT如何实现前后端分离权限控制》本篇将手把手教你用SpringSecurity+JWT搭建一套完整的登录认证与权限控制体系,具有很好的参考价值,希望对大家... 目录Spring Security+JWT实现前后端分离权限控制实战一、为什么要用 JWT?二、JWT 基本结构

Java实现优雅日期处理的方案详解

《Java实现优雅日期处理的方案详解》在我们的日常工作中,需要经常处理各种格式,各种类似的的日期或者时间,下面我们就来看看如何使用java处理这样的日期问题吧,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言一、日期的坑1.1 日期格式化陷阱1.2 时区转换二、优雅方案的进阶之路2.1 线程安全重构2

Android实现两台手机屏幕共享和远程控制功能

《Android实现两台手机屏幕共享和远程控制功能》在远程协助、在线教学、技术支持等多种场景下,实时获得另一部移动设备的屏幕画面,并对其进行操作,具有极高的应用价值,本项目旨在实现两台Android手... 目录一、项目概述二、相关知识2.1 MediaProjection API2.2 Socket 网络

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB

Redis消息队列实现异步秒杀功能

《Redis消息队列实现异步秒杀功能》在高并发场景下,为了提高秒杀业务的性能,可将部分工作交给Redis处理,并通过异步方式执行,Redis提供了多种数据结构来实现消息队列,总结三种,本文详细介绍Re... 目录1 Redis消息队列1.1 List 结构1.2 Pub/Sub 模式1.3 Stream 结