Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络

本文主要是介绍Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文是对Transfomer重要模块的源码解析,完整笔记链接点这里!

缩放点积自注意力 (Scaled Dot-Product Attention)

缩放点积自注意力是一种自注意力机制,它通过查询(Query)、键(Key)和值(Value)的关系来计算注意力权重。该机制的核心在于先计算查询和所有键的点积,然后进行缩放处理,应用softmax函数得到最终的注意力权重,最后用这些权重对值进行加权求和。

源码解析:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()self.temperature = temperature  # 温度参数,用于缩放点积self.dropout = nn.Dropout(attn_dropout)  # Dropout层def forward(self, q, k, v, mask=None):attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # 计算缩放后的点积if mask is not None:attn = attn.masked_fill(mask == 0, -1e9)  # 掩码操作,将需要忽略的位置设置为一个非常小的值attn = self.dropout(F.softmax(attn, dim=-1))  # 应用softmax函数并进行dropoutoutput = torch.matmul(attn, v)  # 使用注意力权重对值(v)进行加权求和return output, attn
  • __init__ 方法中的 temperature 参数用于缩放点积,通常设置为键(Key)维度的平方根。attn_dropout 是在应用softmax函数后进行dropout的比例。
  • forward 方法计算缩放点积自注意力。首先,它计算查询(q)和键(k)的点积,并通过除以 temperature 进行缩放。如果提供了 mask,则会使用 masked_fill 将掩码位置的注意力权重设为一个非常小的负数(这里是 -1e9),使得softmax后这些位置的权重接近于0。之后,应用dropout和softmax函数得到最终的注意力权重。最后,使用这些权重对值(v)进行加权求和得到输出。

多头注意力 (Multi-Head Attention)

多头注意力通过将输入分割成多个头,让每个头在不同的子空间表示上计算注意力,然后将这些头的输出合并。这样做可以让模型在多个子空间中捕获丰富的信息。

源码解析:
import torch.nn as nn
import torch.nn.functional as F
from transformer.Modules import ScaledDotProductAttentionclass MultiHeadAttention(nn.Module):''' Multi-Head Attention module '''def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):super().__init__()self.n_head = n_head  # 头的数量self.d_k = d_k  # 键/查询的维度self.d_v = d_v  # 值的维度self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)  # 查询的线性变换self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)  # 键的线性变换self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)  # 值的线性变换self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # 输出的线性变换self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)  # 缩放点积注意力模块self.dropout = nn.Dropout(dropout)  # Dropout层self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)  # 层归一化def forward(self, q, k, v, mask=None):# 保存输入以便后面进行残差连接residual = q# 线性变换并重塑以准备多头计算q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)# 转置以将头维度提前,便于并行计算q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)# 如果存在掩码,则扩展掩码以适应头维度if mask is not None:mask = mask.unsqueeze(1)   # 为头维度广播掩码# 调用缩放点积注意力模块q, attn = self.attention(q, k, v, mask=mask)# 转置并重塑以合并多头q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)# 应用线性变换和dropoutq = self.dropout(self.fc(q))# 添加残差连接并进行层归一化q += residualq = self.layer_norm(q)# 返回多头注意力的输出和注意力权重return q, attn
  • __init__ 方法初始化了多头注意力的参数,包括头的数量 n_head,查询/键/值的维度 d_kd_v,以及线性层 w_qsw_ksw_vsfc
  • forward 方法首先将输入 qkv 通过线性层映射到多头的维度,然后重塑并转置以便进行并行计算。如果存在掩码,它会被扩展以适应头维度。调用缩放点积注意力模块计算注意力,之后合并多头输出,并应用线性变换和dropout。最后,添加残差连接和层归一化。

前馈网络 (Positionwise FeedForward)

前馈网络(FFN)在自注意力层之后应用,用于进行非线性变换,增加模型的复杂度和表达能力。

源码解析:
import torch.nn as nn
import torch.nn.functional as Fclass PositionwiseFeedForward(nn.Module):''' A two-feed-forward-layer module '''def __init__(self, d_in, d_hid, dropout=0.1):super().__init__()self.w_1 = nn.Linear(d_in, d_hid)  # 第一个线性层self.w_2 = nn.Linear(d_hid, d_in)  # 第二个线性层self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)  # 层归一化self.dropout = nn.Dropout(dropout)  # Dropout层def forward(self, x):# 保存输入以便后面进行残差连接residual = x# 通过第一个线性层,然后应用ReLU激活函数x = self.w_1(x)x = F.relu(x)# 通过第二个线性层x = self.w_2(x)# 应用dropoutx = self.dropout(x)# 添加残差连接并进行层归一化x += residualx = self.layer_norm(x)# 返回输出return x
  • __init__ 方法初始化了两个线性层 w_1w_2,层归一化 layer_norm,以及dropout层。
  • forward 方法首先通过第一个线性层和ReLU激活函数,然后通过第二个线性层。应用dropout层后,添加残差连接并进行层归一化。

这篇关于Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/541623

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

【测试】输入正确用户名和密码,点击登录没有响应的可能性原因

目录 一、前端问题 1. 界面交互问题 2. 输入数据校验问题 二、网络问题 1. 网络连接中断 2. 代理设置问题 三、后端问题 1. 服务器故障 2. 数据库问题 3. 权限问题: 四、其他问题 1. 缓存问题 2. 第三方服务问题 3. 配置问题 一、前端问题 1. 界面交互问题 登录按钮的点击事件未正确绑定,导致点击后无法触发登录操作。 页面可能存在

ASIO网络调试助手之一:简介

多年前,写过几篇《Boost.Asio C++网络编程》的学习文章,一直没机会实践。最近项目中用到了Asio,于是抽空写了个网络调试助手。 开发环境: Win10 Qt5.12.6 + Asio(standalone) + spdlog 支持协议: UDP + TCP Client + TCP Server 独立的Asio(http://www.think-async.com)只包含了头文件,不依

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

poj 3181 网络流,建图。

题意: 农夫约翰为他的牛准备了F种食物和D种饮料。 每头牛都有各自喜欢的食物和饮料,而每种食物和饮料都只能分配给一头牛。 问最多能有多少头牛可以同时得到喜欢的食物和饮料。 解析: 由于要同时得到喜欢的食物和饮料,所以网络流建图的时候要把牛拆点了。 如下建图: s -> 食物 -> 牛1 -> 牛2 -> 饮料 -> t 所以分配一下点: s  =  0, 牛1= 1~

poj 3068 有流量限制的最小费用网络流

题意: m条有向边连接了n个仓库,每条边都有一定费用。 将两种危险品从0运到n-1,除了起点和终点外,危险品不能放在一起,也不能走相同的路径。 求最小的费用是多少。 解析: 抽象出一个源点s一个汇点t,源点与0相连,费用为0,容量为2。 汇点与n - 1相连,费用为0,容量为2。 每条边之间也相连,费用为每条边的费用,容量为1。 建图完毕之后,求一条流量为2的最小费用流就行了