torch.einsum详解

2024-08-20 23:44
文章标签 torch einsum 详解

本文主要是介绍torch.einsum详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.einsum 是 PyTorch 中用于执行高效张量运算的函数,基于爱因斯坦求和约定(Einstein summation convention)。它能够处理复杂的张量操作,并简化代码书写。

基本语法

torch.einsum(subscripts, *operands)
  • subscripts:一个字符串,用于描述输入张量的维度如何结合。
  • *operands:待操作的张量。

爱因斯坦求和约定

爱因斯坦求和约定是一个简化张量运算的方式,省略了显式的求和符号。通过指定各维度的标签,可以直接描述复杂的张量运算。

语法结构

  • "nqhd,nkhd->nhqk": 这个字符串描述了如何对两个张量进行操作,并生成输出张量的维度。

    • n:批次大小(batch size)
    • q:查询序列长度(query length)
    • k:键序列长度(key length)
    • h:注意力头的数量(number of heads)
    • d:每个注意力头的维度(dimension per head)

示例代码

以下是使用 torch.einsum 计算多头注意力机制中点积相似性的示例代码:

import torch# 定义多头注意力机制的点积计算函数
def compute_attention_scores(queries, keys):# 计算点积相似性分数energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])return energy# 示例数据
N = 1            # 批次大小
q = 2            # 查询序列长度
k = 3            # 键序列长度
h = 2            # 注意力头数量
d = 4            # 每个注意力头的维度# 随机生成 queries 和 keys
queries = torch.rand((N, q, h, d))  # Shape (1, 2, 2, 4)
keys = torch.rand((N, k, h, d))    # Shape (1, 3, 2, 4)# 计算注意力分数
energy = compute_attention_scores(queries, keys)print("Energy shape:", energy.shape)
print(energy)

计算过程

  1. 维度解释

    • queries 的维度为 (1, 2, 2, 4)N = 1(批次大小),q = 2(查询序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
    • keys 的维度为 (1, 3, 2, 4)N = 1(批次大小),k = 3(键序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
  2. 点积计算

    • 对每个批次和每个头部,计算 querieskeysd 维度上的点积。
    • 结果的维度为 (N, h, q, k),其中:
      • N 是批次大小
      • h 是注意力头的数量
      • q 是查询序列的长度
      • k 是键序列的长度

    点积计算的实际操作是:

    • 对于每个批次(n)和每个头部(h),对 querieskeys 张量在 d 维度上进行点积运算,得到形状为 (q, k) 的张量。

简单计算示例

假设我们有如下示例数据:

queries = torch.tensor([[[[1.0, 0.5, 0.2, 1.5], [0.3, 0.7, 0.6, 0.8]], [[0.9, 0.4, 1.2, 0.5], [0.2, 0.6, 0.8, 0.7]]]])
keys = torch.tensor([[[[0.1, 1.0, 0.3, 0.5], [0.2, 0.4, 0.6, 0.7], [0.8, 1.0, 0.9, 0.5]], [[0.1, 0.5, 0.2, 0.8], [0.3, 0.4, 0.7, 0.9], [0.6, 0.8, 1.0, 0.2]]]])

点积计算

  • 对于第一个批次和第一个头部:

    • queries[0, :, 0, :]keys[0, :, 0, :] 的点积计算如下:

    计算:

    energy[0, 0, 0, 0] = (1.0*0.1 + 0.5*1.0 + 0.2*0.3 + 1.5*0.5) = 0.1 + 0.5 + 0.06 + 0.75 = 1.41
    energy[0, 0, 0, 1] = (1.0*0.2 + 0.5*0.4 + 0.2*0.6 + 1.5*0.7) = 0.2 + 0.2 + 0.12 + 1.05 = 1.59
    energy[0, 0, 0, 2] = (1.0*0.8 + 0.5*1.0 + 0.2*0.9 + 1.5*0.5) = 0.8 + 0.5 + 0.18 + 0.75 = 1.23
    energy[0, 0, 1, 0] = (0.3*0.1 + 0.7*1.0 + 0.6*0.3 + 0.8*0.5) = 0.03 + 0.7 + 0.18 + 0.4 = 1.31
    energy[0, 0, 1, 1] = (0.3*0.2 + 0.7*0.4 + 0.6*0.6 + 0.8*0.7) = 0.06 + 0.28 + 0.36 + 0.56 = 1.26
    energy[0, 0, 1, 2] = (0.3*0.8 + 0.7*1.0 + 0.6*0.9 + 0.8*0.5) = 0.24 + 0.7 + 0.54 + 0.4 = 1.88
    

总结

torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) 用于计算 querieskeys 张量在注意力机制中的点积,相似性得分。它通过爱因斯坦求和约定指定了如何在多维张量上执行这些操作,使得代码更简洁、效率更高。

Code

AI_With_NumPy
此项目汇集了很多AI相关的代码实现,供大家学习使用,欢迎点赞收藏👏🏻

备注

个人水平有限,有问题随时交流~

这篇关于torch.einsum详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1091516

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

嵌入式Openharmony系统构建与启动详解

大家好,今天主要给大家分享一下,如何构建Openharmony子系统以及系统的启动过程分解。 第一:OpenHarmony系统构建      首先熟悉一下,构建系统是一种自动化处理工具的集合,通过将源代码文件进行一系列处理,最终生成和用户可以使用的目标文件。这里的目标文件包括静态链接库文件、动态链接库文件、可执行文件、脚本文件、配置文件等。      我们在编写hellowor

LabVIEW FIFO详解

在LabVIEW的FPGA开发中,FIFO(先入先出队列)是常用的数据传输机制。通过配置FIFO的属性,工程师可以在FPGA和主机之间,或不同FPGA VIs之间进行高效的数据传输。根据具体需求,FIFO有多种类型与实现方式,包括目标范围内FIFO(Target-Scoped)、DMA FIFO以及点对点流(Peer-to-Peer)。 FIFO类型 **目标范围FIFO(Target-Sc

019、JOptionPane类的常用静态方法详解

目录 JOptionPane类的常用静态方法详解 1. showInputDialog()方法 1.1基本用法 1.2带有默认值的输入框 1.3带有选项的输入对话框 1.4自定义图标的输入对话框 2. showConfirmDialog()方法 2.1基本用法 2.2自定义按钮和图标 2.3带有自定义组件的确认对话框 3. showMessageDialog()方法 3.1

脏页的标记方式详解

脏页的标记方式 一、引言 在数据库系统中,脏页是指那些被修改过但还未写入磁盘的数据页。为了有效地管理这些脏页并确保数据的一致性,数据库需要对脏页进行标记。了解脏页的标记方式对于理解数据库的内部工作机制和优化性能至关重要。 二、脏页产生的过程 当数据库中的数据被修改时,这些修改首先会在内存中的缓冲池(Buffer Pool)中进行。例如,执行一条 UPDATE 语句修改了某一行数据,对应的缓

OmniGlue论文详解(特征匹配)

OmniGlue论文详解(特征匹配) 摘要1. 引言2. 相关工作2.1. 广义局部特征匹配2.2. 稀疏可学习匹配2.3. 半稠密可学习匹配2.4. 与其他图像表示匹配 3. OmniGlue3.1. 模型概述3.2. OmniGlue 细节3.2.1. 特征提取3.2.2. 利用DINOv2构建图形。3.2.3. 信息传播与新的指导3.2.4. 匹配层和损失函数3.2.5. 与Super

web群集--nginx配置文件location匹配符的优先级顺序详解及验证

文章目录 前言优先级顺序优先级顺序(详解)1. 精确匹配(Exact Match)2. 正则表达式匹配(Regex Match)3. 前缀匹配(Prefix Match) 匹配规则的综合应用验证优先级 前言 location的作用 在 NGINX 中,location 指令用于定义如何处理特定的请求 URI。由于网站往往需要不同的处理方式来适应各种请求,NGINX 提供了多种匹