Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络

本文主要是介绍Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文是对Transfomer重要模块的源码解析,完整笔记链接点这里!

缩放点积自注意力 (Scaled Dot-Product Attention)

缩放点积自注意力是一种自注意力机制,它通过查询(Query)、键(Key)和值(Value)的关系来计算注意力权重。该机制的核心在于先计算查询和所有键的点积,然后进行缩放处理,应用softmax函数得到最终的注意力权重,最后用这些权重对值进行加权求和。

源码解析:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.1):super().__init__()self.temperature = temperature  # 温度参数,用于缩放点积self.dropout = nn.Dropout(attn_dropout)  # Dropout层def forward(self, q, k, v, mask=None):attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # 计算缩放后的点积if mask is not None:attn = attn.masked_fill(mask == 0, -1e9)  # 掩码操作,将需要忽略的位置设置为一个非常小的值attn = self.dropout(F.softmax(attn, dim=-1))  # 应用softmax函数并进行dropoutoutput = torch.matmul(attn, v)  # 使用注意力权重对值(v)进行加权求和return output, attn
  • __init__ 方法中的 temperature 参数用于缩放点积,通常设置为键(Key)维度的平方根。attn_dropout 是在应用softmax函数后进行dropout的比例。
  • forward 方法计算缩放点积自注意力。首先,它计算查询(q)和键(k)的点积,并通过除以 temperature 进行缩放。如果提供了 mask,则会使用 masked_fill 将掩码位置的注意力权重设为一个非常小的负数(这里是 -1e9),使得softmax后这些位置的权重接近于0。之后,应用dropout和softmax函数得到最终的注意力权重。最后,使用这些权重对值(v)进行加权求和得到输出。

多头注意力 (Multi-Head Attention)

多头注意力通过将输入分割成多个头,让每个头在不同的子空间表示上计算注意力,然后将这些头的输出合并。这样做可以让模型在多个子空间中捕获丰富的信息。

源码解析:
import torch.nn as nn
import torch.nn.functional as F
from transformer.Modules import ScaledDotProductAttentionclass MultiHeadAttention(nn.Module):''' Multi-Head Attention module '''def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):super().__init__()self.n_head = n_head  # 头的数量self.d_k = d_k  # 键/查询的维度self.d_v = d_v  # 值的维度self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)  # 查询的线性变换self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)  # 键的线性变换self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)  # 值的线性变换self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # 输出的线性变换self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)  # 缩放点积注意力模块self.dropout = nn.Dropout(dropout)  # Dropout层self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)  # 层归一化def forward(self, q, k, v, mask=None):# 保存输入以便后面进行残差连接residual = q# 线性变换并重塑以准备多头计算q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)# 转置以将头维度提前,便于并行计算q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)# 如果存在掩码,则扩展掩码以适应头维度if mask is not None:mask = mask.unsqueeze(1)   # 为头维度广播掩码# 调用缩放点积注意力模块q, attn = self.attention(q, k, v, mask=mask)# 转置并重塑以合并多头q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)# 应用线性变换和dropoutq = self.dropout(self.fc(q))# 添加残差连接并进行层归一化q += residualq = self.layer_norm(q)# 返回多头注意力的输出和注意力权重return q, attn
  • __init__ 方法初始化了多头注意力的参数,包括头的数量 n_head,查询/键/值的维度 d_kd_v,以及线性层 w_qsw_ksw_vsfc
  • forward 方法首先将输入 qkv 通过线性层映射到多头的维度,然后重塑并转置以便进行并行计算。如果存在掩码,它会被扩展以适应头维度。调用缩放点积注意力模块计算注意力,之后合并多头输出,并应用线性变换和dropout。最后,添加残差连接和层归一化。

前馈网络 (Positionwise FeedForward)

前馈网络(FFN)在自注意力层之后应用,用于进行非线性变换,增加模型的复杂度和表达能力。

源码解析:
import torch.nn as nn
import torch.nn.functional as Fclass PositionwiseFeedForward(nn.Module):''' A two-feed-forward-layer module '''def __init__(self, d_in, d_hid, dropout=0.1):super().__init__()self.w_1 = nn.Linear(d_in, d_hid)  # 第一个线性层self.w_2 = nn.Linear(d_hid, d_in)  # 第二个线性层self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)  # 层归一化self.dropout = nn.Dropout(dropout)  # Dropout层def forward(self, x):# 保存输入以便后面进行残差连接residual = x# 通过第一个线性层,然后应用ReLU激活函数x = self.w_1(x)x = F.relu(x)# 通过第二个线性层x = self.w_2(x)# 应用dropoutx = self.dropout(x)# 添加残差连接并进行层归一化x += residualx = self.layer_norm(x)# 返回输出return x
  • __init__ 方法初始化了两个线性层 w_1w_2,层归一化 layer_norm,以及dropout层。
  • forward 方法首先通过第一个线性层和ReLU激活函数,然后通过第二个线性层。应用dropout层后,添加残差连接并进行层归一化。

这篇关于Transfomer重要源码解析:缩放点击注意力,多头自注意力,前馈网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/541623

相关文章

PostgreSQL的扩展dict_int应用案例解析

《PostgreSQL的扩展dict_int应用案例解析》dict_int扩展为PostgreSQL提供了专业的整数文本处理能力,特别适合需要精确处理数字内容的搜索场景,本文给大家介绍PostgreS... 目录PostgreSQL的扩展dict_int一、扩展概述二、核心功能三、安装与启用四、字典配置方法

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

Java中的雪花算法Snowflake解析与实践技巧

《Java中的雪花算法Snowflake解析与实践技巧》本文解析了雪花算法的原理、Java实现及生产实践,涵盖ID结构、位运算技巧、时钟回拨处理、WorkerId分配等关键点,并探讨了百度UidGen... 目录一、雪花算法核心原理1.1 算法起源1.2 ID结构详解1.3 核心特性二、Java实现解析2.

使用Python绘制3D堆叠条形图全解析

《使用Python绘制3D堆叠条形图全解析》在数据可视化的工具箱里,3D图表总能带来眼前一亮的效果,本文就来和大家聊聊如何使用Python实现绘制3D堆叠条形图,感兴趣的小伙伴可以了解下... 目录为什么选择 3D 堆叠条形图代码实现:从数据到 3D 世界的搭建核心代码逐行解析细节优化应用场景:3D 堆叠图

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

全面解析MySQL索引长度限制问题与解决方案

《全面解析MySQL索引长度限制问题与解决方案》MySQL对索引长度设限是为了保持高效的数据检索性能,这个限制不是MySQL的缺陷,而是数据库设计中的权衡结果,下面我们就来看看如何解决这一问题吧... 目录引言:为什么会有索引键长度问题?一、问题根源深度解析mysql索引长度限制原理实际场景示例二、五大解决

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现