multiheadattention类原理及源码理解

2023-11-01 19:28

本文主要是介绍multiheadattention类原理及源码理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

网络找的一段代码如下:

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]#这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量#对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

python、pytorch、人工智能相关知识现阶段都是简单的了解,没有相关的实践。因此在学习的时候不要习惯性的扣代码细节。能把论文原理和代码逻辑对应即可、能总结代码块重点内容即可。

transformer中self-attention就是对一个输入序列计算每个位置的注意力,每个位置在论文原文中用d_model(512)维表示,多头就是每个位置用h(原文中8个)个头计算,这样每个头计算一个位置中的64维特征。

自注意力机制有什么好处呢?

自注意力机制的目的是让模型能够同时关注输入序列中的不同位置和信息,从而捕捉序列中的复杂模式和关系。通过计算每个位置的向量与其他位置的向量之间的相似度或相关性,模型可以学习到序列中每个元素对于输出结果的重要性,从而给予不同的权重。

为什么要使用多头呢?下面是我找到的解释:

多头计算可以让模型同时关注输入序列中的不同方面和细节,从而增强模型的表达能力和学习能力。每个注意力头可以捕捉输入序列中的不同模式和关系,而最终的线性变换可以将这些信息融合在一起。
多头计算可以降低模型的复杂度和计算成本。对于较大的 d_model 来说,如果只使用单头计算,那么 QK^T 的结果会非常大,导致 softmax 函数的梯度非常小,不利于网络的训练。而使用多头计算,可以将 d_model 分割成 h 个较小的子空间,从而减少计算量和内存消耗34。
多头计算还可以
提高模型的可解释性和泛化能力
。我们可以从模型中检查不同注意力头的分布,观察模型是如何关注不同位置和信息的。各个注意力头可以学会执行不同的任务,例如语法分析、实体识别等

MultiHeadedAttention类还做了什么事情?
1、通过4个线性层(通常是4)计算得到Q K V矩阵
在transformer中,Q、K、V是通过四个线性层得到的,分别是:
Q = XW^Q ,其中X是embedding输入矩阵,W^Q 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到Q空间。
K = XW^K ,其中X是embedding输入矩阵,W^K 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到K空间。
V = XW^V ,其中Xembedding是输入矩阵,W^V 是一个可训练的参数矩阵,大小为(d_model* d_model)用于将X映射到V空间。

这篇关于multiheadattention类原理及源码理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/325201

相关文章

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

【C++高阶】C++类型转换全攻略:深入理解并高效应用

📝个人主页🌹:Eternity._ ⏩收录专栏⏪:C++ “ 登神长阶 ” 🤡往期回顾🤡:C++ 智能指针 🌹🌹期待您的关注 🌹🌹 ❀C++的类型转换 📒1. C语言中的类型转换📚2. C++强制类型转换⛰️static_cast🌞reinterpret_cast⭐const_cast🍁dynamic_cast 📜3. C++强制类型转换的原因📝

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

hdu4407容斥原理

题意: 有一个元素为 1~n 的数列{An},有2种操作(1000次): 1、求某段区间 [a,b] 中与 p 互质的数的和。 2、将数列中某个位置元素的值改变。 import java.io.BufferedInputStream;import java.io.BufferedReader;import java.io.IOException;import java.io.Inpu