haiku实现门控多头注意力模块

2024-01-10 09:52

本文主要是介绍haiku实现门控多头注意力模块,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在多头注意力机制中,通常输入的数据包括查询(Q)、键(K)和值(V)。这些数据的维度以及权重矩阵的维度在多头注意力机制中扮演关键角色。下面对数据及权重的维度进行解释:

  1. 输入数据(Queries, Keys, Values):

    • Queries (Q): 表示待查询的信息,通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, q_dim),其中 q_dim 是查询向量的维度。
    • Keys (K): 表示用于计算注意力分数的信息,也通常对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, key_dim),其中 key_dim 是键向量的维度。
    • Values (V): 表示待加权求和的信息,同样对应输入序列的每个位置。其维度通常为 (batch_size, seq_length, value_dim),其中 value_dim 是值向量的维度。
  2. 权重矩阵:

    • 查询权重矩阵 (Q_weights): 用于对查询(Q)进行线性变换,将其映射到多个注意力头的维度。其维度通常为 (q_dim, num_heads, head_dim),其中 num_heads 是注意力头的数量,head_dim 是每个注意力头的维度。
    • 键权重矩阵 (K_weights): 用于对键(K)进行线性变换,同样映射到多个注意力头的维度。其维度通常为 (key_dim, num_heads, head_dim)。
    • 值权重矩阵 (V_weights): 用于对值(V)进行线性变换,映射到多个注意力头的维度。其维度通常为 (value_dim, num_heads, head_dim)。
def glorot_uniform():return hk.initializers.VarianceScaling(scale=1.0,mode='fan_avg',distribution='uniform')def stable_softmax(logits: jax.Array) -> jax.Array:"""Numerically stable softmax for (potential) bfloat 16."""if logits.dtype == jnp.float32:output = jax.nn.softmax(logits)elif logits.dtype == jnp.bfloat16:# Need to explicitly do softmax in float32 to avoid numerical issues# with large negatives. Large negatives can occur if trying to mask# by adding on large negative logits so that things softmax to zero.output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)else:raise ValueError(f'Unexpected input dtype {logits.dtype}')return outputclass Attention(hk.Module):"""Multihead attention."""def __init__(self, config, global_config, output_dim, name='attention'):super().__init__(name=name)self.config = configself.global_config = global_configself.output_dim = output_dimdef __call__(self, q_data, m_data, mask, nonbatched_bias=None):"""Builds Attention module.Arguments:q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].m_data: A tensor of memories from which the keys and values areprojected, shape [batch_size, N_keys, m_channels].mask: A mask for the attention, shape [batch_size, N_queries, N_keys].nonbatched_bias: Shared bias, shape [N_queries, N_keys].Returns:A float32 tensor of shape [batch_size, N_queries, output_dim]."""# Sensible default for when the config keys are missingkey_dim = self.config.get('key_dim', int(q_data.shape[-1]))value_dim = self.config.get('value_dim', int(m_data.shape[-1]))num_head = self.config.num_headassert key_dim % num_head == 0assert value_dim % num_head == 0key_dim = key_dim // num_headvalue_dim = value_dim // num_head# weights维度(数据最后一维的维度数,注意力头数量,每个注意力头映射的数据维度)q_weights = hk.get_parameter('query_w', shape=(q_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())k_weights = hk.get_parameter('key_w', shape=(m_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())v_weights = hk.get_parameter('value_w', shape=(m_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=glorot_uniform())# bqa: 输入张量 q_data 的轴的标记。(batch_size, seq_length, q_dim)# 'b' :batch 维度,'q':查询序列维度,'a' 查询向量的维度。所以,'bqa' 表示 q_data 的三个轴。# ahc:查询权重矩阵的形状, a:查询向量的维度,h:注意力头的数量,c: 每个注意力头中查询的维度。# key_dim**(-0.5) 注意力缩放,避免注意力分数过大或过小# jnp.einsum:Einstein Summation Notation(爱因斯坦求和约定)。# 一种紧凑、灵活的方式来指定和计算张量的乘积、求和和转置等操作。q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)# 注意力分数,计算每个查询(q)和键(k)之间的点积,以获得注意力分数。# 结果维度为bhqk (batch_size, num_heads, num_q, num_k), # num_q/num_k为查询/键的数量,一般为 seq_length。logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)if nonbatched_bias is not None:logits += jnp.expand_dims(nonbatched_bias, axis=0)# 注意力分数中加入masklogits = jnp.where(mask, logits, _SOFTMAX_MASK)# 对注意力分数进行softmax操作,我们得到每个位置对输入序列的权重分配。weights = stable_softmax(logits)# 注意力分数对值进行加权求和,得到多头注意力机制的输出# 两个向量的点积可以用于度量它们之间的相似性。如果两个向量越相似,它们的点积就越大weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)if self.global_config.zero_init:init = hk.initializers.Constant(0.0)else:init = glorot_uniform()# 带有bias的门控注意力if self.config.gating:gating_weights = hk.get_parameter('gating_w',shape=(q_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))gating_bias = hk.get_parameter('gating_b',shape=(num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(1.0))gate_values = jnp.einsum('bqc, chv->bqhv', q_data,gating_weights) + gating_biasgate_values = jax.nn.sigmoid(gate_values)# ⊙ 对应元素相乘weighted_avg *= gate_valueso_weights = hk.get_parameter('output_w', shape=(num_head, value_dim, self.output_dim),dtype=q_data.dtype,init=init)o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))# 线性变换到输出维度大小output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_biasreturn output

这篇关于haiku实现门控多头注意力模块的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/590395

相关文章

Vue项目的甘特图组件之dhtmlx-gantt使用教程和实现效果展示(推荐)

《Vue项目的甘特图组件之dhtmlx-gantt使用教程和实现效果展示(推荐)》文章介绍了如何使用dhtmlx-gantt组件来实现公司的甘特图需求,并提供了一个简单的Vue组件示例,文章还分享了一... 目录一、首先 npm 安装插件二、创建一个vue组件三、业务页面内 引用自定义组件:四、dhtmlx

Vue ElementUI中Upload组件批量上传的实现代码

《VueElementUI中Upload组件批量上传的实现代码》ElementUI中Upload组件批量上传通过获取upload组件的DOM、文件、上传地址和数据,封装uploadFiles方法,使... ElementUI中Upload组件如何批量上传首先就是upload组件 <el-upl

Node.js net模块的使用示例

《Node.jsnet模块的使用示例》本文主要介绍了Node.jsnet模块的使用示例,net模块支持TCP通信,处理TCP连接和数据传输,具有一定的参考价值,感兴趣的可以了解一下... 目录简介引入 net 模块核心概念TCP (传输控制协议)Socket服务器TCP 服务器创建基本服务器服务器配置选项服

Docker部署Jenkins持续集成(CI)工具的实现

《Docker部署Jenkins持续集成(CI)工具的实现》Jenkins是一个流行的开源自动化工具,广泛应用于持续集成(CI)和持续交付(CD)的环境中,本文介绍了使用Docker部署Jenkins... 目录前言一、准备工作二、设置变量和目录结构三、配置 docker 权限和网络四、启动 Jenkins

Python3脚本实现Excel与TXT的智能转换

《Python3脚本实现Excel与TXT的智能转换》在数据处理的日常工作中,我们经常需要将Excel中的结构化数据转换为其他格式,本文将使用Python3实现Excel与TXT的智能转换,需要的可以... 目录场景应用:为什么需要这种转换技术解析:代码实现详解核心代码展示改进点说明实战演练:从Excel到

如何使用CSS3实现波浪式图片墙

《如何使用CSS3实现波浪式图片墙》:本文主要介绍了如何使用CSS3的transform属性和动画技巧实现波浪式图片墙,通过设置图片的垂直偏移量,并使用动画使其周期性地改变位置,可以创建出动态且具有波浪效果的图片墙,同时,还强调了响应式设计的重要性,以确保图片墙在不同设备上都能良好显示,详细内容请阅读本文,希望能对你有所帮助...

C# string转unicode字符的实现

《C#string转unicode字符的实现》本文主要介绍了C#string转unicode字符的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随... 目录1. 获取字符串中每个字符的 Unicode 值示例代码:输出:2. 将 Unicode 值格式化

python安装whl包并解决依赖关系的实现

《python安装whl包并解决依赖关系的实现》本文主要介绍了python安装whl包并解决依赖关系的实现,文中通过图文示例介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录一、什么是whl文件?二、我们为什么需要使用whl文件来安装python库?三、我们应该去哪儿下

Python脚本实现图片文件批量命名

《Python脚本实现图片文件批量命名》这篇文章主要为大家详细介绍了一个用python第三方库pillow写的批量处理图片命名的脚本,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言源码批量处理图片尺寸脚本源码GUI界面源码打包成.exe可执行文件前言本文介绍一个用python第三方库pi

Java中将异步调用转为同步的五种实现方法

《Java中将异步调用转为同步的五种实现方法》本文介绍了将异步调用转为同步阻塞模式的五种方法:wait/notify、ReentrantLock+Condition、Future、CountDownL... 目录异步与同步的核心区别方法一:使用wait/notify + synchronized代码示例关键