Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵

本文主要是介绍Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

Transformer学习

  • 1 位置编码模块
    • 1.1 PE代码
    • 1.2 测试PE
  • 2 多头自注意力模块
    • 2.1 多头自注意力代码
    • 2.2 测试多头注意力
  • 3 未来序列掩码矩阵

在这里插入图片描述

1 位置编码模块

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i)=\sin(pos/10000^{2i/d_{\mathrm{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i+1)=\cos(pos/10000^{2i/d_\mathrm{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

pos 是序列中每个对象的索引, p o s ∈ [ 0 , m a x s e q l e n ] pos\in [0,max_seq_len] pos[0,maxseqlen], i i i 向量维度序号, i ∈ [ 0 , e m b e d d i m / 2 ] i\in [0,embed_dim/2] i[0,embeddim/2], d m o d e l d_{model} dmodel是模型的embedding维度

1.1 PE代码

import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import seaborn as snsdef get_pos_ecoding(max_seq_len,embed_dim):# 初始化位置矩阵 [max_seq_len,embed_dim]pe = torch.zeros(max_seq_len,embed_dim])position = torch.arange(0,max_seq_len).unsqueeze(1) # [max_seq_len,1]print("位置:", position,position.shape)div_term = torch.exp(torch.arange(0,embed_dim,2)*-(math.log(10000.0)/embed_dim))   # 除项维度为embed_dim的一半,因为对矩阵分奇数和偶数位置进行填充。pe[:,0::2] = torch.sin(position/div_term)pe[:,1::2] = torch.cos(position/div_term)return pe

1.2 测试PE

pe = get_pos_ecoding(8,4)
plt.figure(figsize=(8,8))
sns.heatmap(pe)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")

输出:
位置: tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]]) torch.Size([8, 1])
除项: tensor([1.0000, 0.0100]) torch.Size([2])
在这里插入图片描述

plt.figure(figsize=(8, 5))
plt.plot(positional_encoding[1:, 1], label="dimension 1")
plt.plot(positional_encoding[1:, 2], label="dimension 2")
plt.plot(positional_encoding[1:, 3], label="dimension 3")
plt.legend()
plt.xlabel("Sequence length")
plt.ylabel("Period of Positional Encoding")

在这里插入图片描述

2 多头自注意力模块

2.1 多头自注意力代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy# 复制网络,即使用几层网络就改变N的数量
# 如 4层线性层  clones(nn.Linear(model_dim,model_dim),4)
def clones(module, N):"Produce N identical layers."return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])# 计算注意力
def attention(q, k, v, mask=None, dropout=None):d_k = q.size(-1)scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, v), p_attn# 计算多头注意力
class Multi_Head_Self_Att(nn.Module):def __init__(self,head,model_dim,dropout=0.1):super(Multi_Head_Self_Att,self).__init__()+assert model_dim % head == 0self.d_k = model_dim/headself.head = headself.linears = clones(nn.Linear(model_dim,model_dim),4)self.att = Noneself.dropout = nn.Dropout(p=dropout)def forward(self,q,k,v,mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = q.size(0)# zip函数 将线性层与q,k,v分别对应(self.linears,q),(self.linears,k),(self.linears,v)# q,k,v [bs,-1,head,embed_dim/head]q,k,v = [l(x).view(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2) for l,x in zip(self.linears,(q,k,v))] # 返回计算注意力之后的值作为x和注意力分数x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, int(self.head * self.d_k))return self.linears[-1](x),self.attn

2.2 测试多头注意力

# 模型参数
head = 4
model_dim = 128
seq_len = 10
dropout = 0.1# 生成示例输入
q = torch.randn(seq_len, model_dim)
k = torch.randn(seq_len, model_dim)
v = torch.randn(seq_len, model_dim)# 创建多头自注意力模块
att = Multi_Head_Self_Att(head, model_dim, dropout=dropout)# 运行模块
output,att = att(q, k, v)# 输出形状
print("Output shape:", output.shape)
print(att.shape())
sns.heatmap(att.squeeze().detach().cpu())

输出
Output shape: torch.Size([10, 1, 128])
torch.Size([10, 4, 1, 1])
在这里插入图片描述

3 未来序列掩码矩阵

作用: 防止泄露未来要预测的部分,掩码矩阵是一个除对角线的上三角矩阵
3.1 代码

def subsequent_mask(size):"Mask out subsequent positions."attn_shape = (1, size, size)subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')print("掩码矩阵:",subsequent_mask)return torch.from_numpy(subsequent_mask) == 0

测试掩码

plt.figure(figsize=(5,5))
print(subsequent_mask(8),subsequent_mask(8).shape)
plt.imshow(subsequent_mask(8)[0])

掩码矩阵:
[[[0 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1]
[0 0 0 0 1 1 1 1]
[0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0]]]

tensor([[[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True]]]) torch.Size([1, 8, 8])
在这里插入图片描述
紫色部分为添加掩码的部分

这篇关于Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/876662

相关文章

Python中re模块结合正则表达式的实际应用案例

《Python中re模块结合正则表达式的实际应用案例》Python中的re模块是用于处理正则表达式的强大工具,正则表达式是一种用来匹配字符串的模式,它可以在文本中搜索和匹配特定的字符串模式,这篇文章主... 目录前言re模块常用函数一、查看文本中是否包含 A 或 B 字符串二、替换多个关键词为统一格式三、提

springboot项目打jar制作成镜像并指定配置文件位置方式

《springboot项目打jar制作成镜像并指定配置文件位置方式》:本文主要介绍springboot项目打jar制作成镜像并指定配置文件位置方式,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录一、上传jar到服务器二、编写dockerfile三、新建对应配置文件所存放的数据卷目录四、将配置文

python3如何找到字典的下标index、获取list中指定元素的位置索引

《python3如何找到字典的下标index、获取list中指定元素的位置索引》:本文主要介绍python3如何找到字典的下标index、获取list中指定元素的位置索引问题,具有很好的参考价值,... 目录enumerate()找到字典的下标 index获取list中指定元素的位置索引总结enumerat

一文深入详解Python的secrets模块

《一文深入详解Python的secrets模块》在构建涉及用户身份认证、权限管理、加密通信等系统时,开发者最不能忽视的一个问题就是“安全性”,Python在3.6版本中引入了专门面向安全用途的secr... 目录引言一、背景与动机:为什么需要 secrets 模块?二、secrets 模块的核心功能1. 基

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

C/C++中OpenCV 矩阵运算的实现

《C/C++中OpenCV矩阵运算的实现》本文主要介绍了C/C++中OpenCV矩阵运算的实现,包括基本算术运算(标量与矩阵)、矩阵乘法、转置、逆矩阵、行列式、迹、范数等操作,感兴趣的可以了解一下... 目录矩阵的创建与初始化创建矩阵访问矩阵元素基本的算术运算 ➕➖✖️➗矩阵与标量运算矩阵与矩阵运算 (逐元

如何更改pycharm缓存路径和虚拟内存分页文件位置(c盘爆红)

《如何更改pycharm缓存路径和虚拟内存分页文件位置(c盘爆红)》:本文主要介绍如何更改pycharm缓存路径和虚拟内存分页文件位置(c盘爆红)问题,具有很好的参考价值,希望对大家有所帮助,如有... 目录先在你打算存放的地方建四个文件夹更改这四个路径就可以修改默认虚拟内存分页js文件的位置接下来从高级-

PyCharm如何更改缓存位置

《PyCharm如何更改缓存位置》:本文主要介绍PyCharm如何更改缓存位置的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录PyCharm更改缓存位置1.打开PyCharm的安装编程目录2.将config、sjsystem、plugins和log的路径

Python logging模块使用示例详解

《Pythonlogging模块使用示例详解》Python的logging模块是一个灵活且强大的日志记录工具,广泛应用于应用程序的调试、运行监控和问题排查,下面给大家介绍Pythonlogging模... 目录一、为什么使用 logging 模块?二、核心组件三、日志级别四、基本使用步骤五、快速配置(bas