transfomer中attention为什么要除以根号d_k

2024-06-02 01:12

本文主要是介绍transfomer中attention为什么要除以根号d_k,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

得到矩阵 Q, K, V之后就可以计算出 Self-Attention 的输出了,计算的公式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=Softmax(dk QKT)V

好处

除以维度的开方,可以将数据向0方向集中,使得经过softmax后的梯度更大.
从数学上分析,可以使得QK的分布和Q/K保持一致,

推导

对于两个独立的正态分布而言,两者的加法的期望和方差就是两个独立分布的期望和方差。
qk_T的计算过程为[len_q,dim][dim,len_k]=[len_q,len_k],qk的元素等于dim个乘积的和。对于0-1分布表乘积不会影响期望和方差,但是求和操作会使得方差乘以dim,因此对qk元素除以sqrt(dim)把标准差压回1.

这里展示一个不严谨的采样可视化过程
假设在query在(0,1)分布,key在(0,1)分布,随机采样lengthdim个点,然后统计querykey_T的散点的分布

import math
import numpy as np
import matplotlib.pyplot as pltdef plot_curve(mu=0, sigma =1):import numpy as npimport matplotlib.pyplot as pltfrom scipy.stats import norm# 设置正态分布的参数# mu, sigma = 0, 1  # 均值和标准差# 创建一个x值的范围,覆盖正态分布的整个区间x = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 1000)# 计算对应的正态分布的概率密度值y = norm.pdf(x, mu, sigma)# 我们可以选择y值较高的点来绘制散点图,以模拟概率密度的分布# 这里我们可以设置一个阈值,只绘制y值大于某个值的点threshold = 0.01  # 可以根据需要调整这个阈值selected_points = y > thresholdplt.plot(x, y, 'r-', lw=2, label='Normal dist. (mu={}, sigma={})'.format(mu, sigma))plt.title('Normal Distribution Scatter Approximation')plt.xlabel('Value')plt.ylabel('Probability Density')plt.legend()plt.grid(True)plt.show()def plot_poins(x):# 因为这是一个一维的正态分布,我们通常只绘制x轴上的点# 但为了模拟二维散点图,我们可以简单地将y轴设置为与x轴相同或固定值(例如0)y = np.zeros_like(x)# 绘制散点图plt.figure(figsize=(8, 6))plt.scatter(x, y, alpha=0.5)  # alpha控制点的透明度plt.title('Normal (0, 1) Distribution Scatter Plot')plt.xlabel('Value')plt.ylabel('Value (or Frequency if binned)')plt.grid(True)plt.show()if __name__ == '__main__':# 设置随机种子以便结果可复现np.random.seed(0)len = 10000dim = 100query = np.random.normal(0, 1, len*dim).reshape(len,dim)key = np.random.normal(0, 1, len*dim).reshape(dim,len)qk = np.matmul(query,key) / math.sqrt(dim)mean_query = query.mean()std_query = np.std(query,ddof=1)mean_key = key.mean()std_key = np.std(key,ddof=1)mean_qk = qk.mean()std_qk = np.std(qk,ddof=1)plot_poins(query)plot_curve(mean_query,std_query)

在这里插入图片描述

这篇关于transfomer中attention为什么要除以根号d_k的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1022628

相关文章

什么是 Flash Attention

Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的, 论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。 下面我

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention 文章目录 一、基本原理1. 变分模态分解(VMD)2. 双向时域卷积(BiTCN)3. 双向门控单元(BiGRU)4. 注意力机制(Attention)总结流程 二、实验结果三、核心代码四、代码获取五、总结 时序预测|变分模态分解-双向时域卷积

阅读笔记--Guiding Attention in End-to-End Driving Models

作者:Diego Porres1, Yi Xiao1, Gabriel Villalonga1, Alexandre Levy1, Antonio M. L ́ opez1,2 出版时间:arXiv:2405.00242v1 [cs.CV] 30 Apr 2024 这篇论文研究了如何引导基于视觉的端到端自动驾驶模型的注意力,以提高它们的驾驶质量和获得更直观的激活图。 摘 要   介绍

基于 BiLSTM+Attention 实现降雨预测多变量时序分类——明日是否降雨

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、自然语言处理、深度强化学习、大型语言模型和迁移学习。 降雨预测作为气象学和水文学领域的重要研究课题,‌对于农业、‌城市规划、

Show,Attend and Tell: Neural Image Caption Generation with Visual Attention

简单的翻译阅读了一下 Abstract 受机器翻译和对象检测领域最新工作的启发,我们引入了一种基于注意力的模型,该模型可以自动学习描述图像的内容。我们描述了如何使用标准的反向传播技术,以确定性的方式训练模型,并通过最大化变分下界随机地训练模型。我们还通过可视化展示了模型如何能够自动学习将注视固定在显着对象上,同时在输出序列中生成相应的单词。我们通过三个基准数据集(Flickr9k,Flickr

深入理解推荐系统:推荐系统中的attention机制

什么是attention机制、在推荐模型中的应用(会介绍相关模型,AFM/DIN/DIEN/DST)和参考文献  什么是attention机制  Attention函数的本质可以被描述为一个查询(query)到一系列(键key-值value)对的映射,在计算attention时主要分为三步 第一步是将query和每个key进行相似度计算得到权重,常用的相似度函数有点积,拼接,感知机等;

高精度加、减、乘、除(高精除以低精)

高精度加法 法1:P1601 A+B Problem(高精) #include <bits/stdc++.h>using namespace std;char s1[510], s2[510];int a[510], b[510], sum[510];int lena, lenb, lens;int main(){cin >> s1 >> s2;lena=strlen(s1);le

PTA L1-037 A除以B

L1-037 A除以B(10分) 真的是简单题哈 —— 给定两个绝对值不超过100的整数A和B,要求你按照“A/B=商”的格式输出结果。 输入格式: 输入在第一行给出两个整数A和B(−100≤A,B≤100),数字间以空格分隔。 输出格式: 在一行中输出结果:如果分母是正数,则输出“A/B=商”;如果分母是负数,则要用括号把分母括起来输出;如果分母为零,则输出的商应为Error。输出的商

注意力机制(Attention mechanism)(中篇)

模型的输入是一组向量,它可以是文字,可以是语音,可以是图。而输出有三种可能性, 第一种可能性是每一个向量都有一个对应的标签。如图1所示,当模型看到输入是4个向 量的时候,它就要输出4个标签。如果是回归问题,每个标签是一个数值。如果是分类问题, 每个标签是一个类别。但是在类型1的问题里面,输入跟输出的长度是一样的。模型不需要 去烦恼要输出多少的标签,输出多少的标量。反正输入是4个向量,输出就是4个标