multiheadattention类原理及源码理解

2023-11-01 19:28

本文主要是介绍multiheadattention类原理及源码理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

网络找的一段代码如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

python、pytorch、人工智能相关知识现阶段都是简单的了解,没有相关的实践。因此在学习的时候不要习惯性的扣代码细节。能把论文原理和代码逻辑对应即可、能总结代码块重点内容即可。

transformer中self-attention就是对一个输入序列计算每个位置的注意力,每个位置在论文原文中用d_model(512)维表示,多头就是每个位置用h(原文中8个)个头计算,这样每个头计算一个位置中的64维特征。

自注意力机制有什么好处呢?

自注意力机制的目的是让模型能够同时关注输入序列中的不同位置和信息,从而捕捉序列中的复杂模式和关系。通过计算每个位置的向量与其他位置的向量之间的相似度或相关性,模型可以学习到序列中每个元素对于输出结果的重要性,从而给予不同的权重。

为什么要使用多头呢?下面是我找到的解释:

多头计算可以让模型同时关注输入序列中的不同方面和细节,从而增强模型的表达能力和学习能力。每个注意力头可以捕捉输入序列中的不同模式和关系,而最终的线性变换可以将这些信息融合在一起。
多头计算可以降低模型的复杂度和计算成本。对于较大的 d_model 来说,如果只使用单头计算,那么 QK^T 的结果会非常大,导致 softmax 函数的梯度非常小,不利于网络的训练。而使用多头计算,可以将 d_model 分割成 h 个较小的子空间,从而减少计算量和内存消耗34。
多头计算还可以
提高模型的可解释性和泛化能力
。我们可以从模型中检查不同注意力头的分布,观察模型是如何关注不同位置和信息的。各个注意力头可以学会执行不同的任务,例如语法分析、实体识别等

MultiHeadedAttention类还做了什么事情?
1、通过4个线性层(通常是4)计算得到Q K V矩阵
在transformer中,Q、K、V是通过四个线性层得到的,分别是:
Q = XW^Q ,其中X是embedding输入矩阵,W^Q 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到Q空间。
K = XW^K ,其中X是embedding输入矩阵,W^K 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到K空间。
V = XW^V ,其中Xembedding是输入矩阵,W^V 是一个可训练的参数矩阵,大小为(d_model* d_model)用于将X映射到V空间。

这篇关于multiheadattention类原理及源码理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/325201

相关文章

Java调用C++动态库超详细步骤讲解(附源码)

《Java调用C++动态库超详细步骤讲解(附源码)》C语言因其高效和接近硬件的特性,时常会被用在性能要求较高或者需要直接操作硬件的场合,:本文主要介绍Java调用C++动态库的相关资料,文中通过代... 目录一、直接调用C++库第一步:动态库生成(vs2017+qt5.12.10)第二步:Java调用C++

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

Python中随机休眠技术原理与应用详解

《Python中随机休眠技术原理与应用详解》在编程中,让程序暂停执行特定时间是常见需求,当需要引入不确定性时,随机休眠就成为关键技巧,下面我们就来看看Python中随机休眠技术的具体实现与应用吧... 目录引言一、实现原理与基础方法1.1 核心函数解析1.2 基础实现模板1.3 整数版实现二、典型应用场景2

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

Spring 中 BeanFactoryPostProcessor 的作用和示例源码分析

《Spring中BeanFactoryPostProcessor的作用和示例源码分析》Spring的BeanFactoryPostProcessor是容器初始化的扩展接口,允许在Bean实例化前... 目录一、概览1. 核心定位2. 核心功能详解3. 关键特性二、Spring 内置的 BeanFactory

JAVA封装多线程实现的方式及原理

《JAVA封装多线程实现的方式及原理》:本文主要介绍Java中封装多线程的原理和常见方式,通过封装可以简化多线程的使用,提高安全性,并增强代码的可维护性和可扩展性,需要的朋友可以参考下... 目录前言一、封装的目标二、常见的封装方式及原理总结前言在 Java 中,封装多线程的原理主要围绕着将多线程相关的操

kotlin中的模块化结构组件及工作原理

《kotlin中的模块化结构组件及工作原理》本文介绍了Kotlin中模块化结构组件,包括ViewModel、LiveData、Room和Navigation的工作原理和基础使用,本文通过实例代码给大家... 目录ViewModel 工作原理LiveData 工作原理Room 工作原理Navigation 工

Java的volatile和sychronized底层实现原理解析

《Java的volatile和sychronized底层实现原理解析》文章详细介绍了Java中的synchronized和volatile关键字的底层实现原理,包括字节码层面、JVM层面的实现细节,以... 目录1. 概览2. Synchronized2.1 字节码层面2.2 JVM层面2.2.1 ente

MySQL的隐式锁(Implicit Lock)原理实现

《MySQL的隐式锁(ImplicitLock)原理实现》MySQL的InnoDB存储引擎中隐式锁是一种自动管理的锁,用于保证事务在行级别操作时的数据一致性和安全性,本文主要介绍了MySQL的隐式锁... 目录1. 背景:什么是隐式锁?2. 隐式锁的工作原理3. 隐式锁的类型4. 隐式锁的实现与源代码分析4