torch.einsum详解

2024-08-20 23:44
文章标签 详解 torch einsum

本文主要是介绍torch.einsum详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.einsum 是 PyTorch 中用于执行高效张量运算的函数,基于爱因斯坦求和约定(Einstein summation convention)。它能够处理复杂的张量操作,并简化代码书写。

基本语法

torch.einsum(subscripts, *operands)
  • subscripts:一个字符串,用于描述输入张量的维度如何结合。
  • *operands:待操作的张量。

爱因斯坦求和约定

爱因斯坦求和约定是一个简化张量运算的方式,省略了显式的求和符号。通过指定各维度的标签,可以直接描述复杂的张量运算。

语法结构

  • "nqhd,nkhd->nhqk": 这个字符串描述了如何对两个张量进行操作,并生成输出张量的维度。

    • n:批次大小(batch size)
    • q:查询序列长度(query length)
    • k:键序列长度(key length)
    • h:注意力头的数量(number of heads)
    • d:每个注意力头的维度(dimension per head)

示例代码

以下是使用 torch.einsum 计算多头注意力机制中点积相似性的示例代码:

import torch# 定义多头注意力机制的点积计算函数
def compute_attention_scores(queries, keys):# 计算点积相似性分数energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])return energy# 示例数据
N = 1            # 批次大小
q = 2            # 查询序列长度
k = 3            # 键序列长度
h = 2            # 注意力头数量
d = 4            # 每个注意力头的维度# 随机生成 queries 和 keys
queries = torch.rand((N, q, h, d))  # Shape (1, 2, 2, 4)
keys = torch.rand((N, k, h, d))    # Shape (1, 3, 2, 4)# 计算注意力分数
energy = compute_attention_scores(queries, keys)print("Energy shape:", energy.shape)
print(energy)

计算过程

  1. 维度解释

    • queries 的维度为 (1, 2, 2, 4)N = 1(批次大小),q = 2(查询序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
    • keys 的维度为 (1, 3, 2, 4)N = 1(批次大小),k = 3(键序列长度),h = 2(注意力头数量),d = 4(每个头部的维度)。
  2. 点积计算

    • 对每个批次和每个头部,计算 querieskeysd 维度上的点积。
    • 结果的维度为 (N, h, q, k),其中:
      • N 是批次大小
      • h 是注意力头的数量
      • q 是查询序列的长度
      • k 是键序列的长度

    点积计算的实际操作是:

    • 对于每个批次(n)和每个头部(h),对 querieskeys 张量在 d 维度上进行点积运算,得到形状为 (q, k) 的张量。

简单计算示例

假设我们有如下示例数据:

queries = torch.tensor([[[[1.0, 0.5, 0.2, 1.5], [0.3, 0.7, 0.6, 0.8]], [[0.9, 0.4, 1.2, 0.5], [0.2, 0.6, 0.8, 0.7]]]])
keys = torch.tensor([[[[0.1, 1.0, 0.3, 0.5], [0.2, 0.4, 0.6, 0.7], [0.8, 1.0, 0.9, 0.5]], [[0.1, 0.5, 0.2, 0.8], [0.3, 0.4, 0.7, 0.9], [0.6, 0.8, 1.0, 0.2]]]])

点积计算

  • 对于第一个批次和第一个头部:

    • queries[0, :, 0, :]keys[0, :, 0, :] 的点积计算如下:

    计算:

    energy[0, 0, 0, 0] = (1.0*0.1 + 0.5*1.0 + 0.2*0.3 + 1.5*0.5) = 0.1 + 0.5 + 0.06 + 0.75 = 1.41
    energy[0, 0, 0, 1] = (1.0*0.2 + 0.5*0.4 + 0.2*0.6 + 1.5*0.7) = 0.2 + 0.2 + 0.12 + 1.05 = 1.59
    energy[0, 0, 0, 2] = (1.0*0.8 + 0.5*1.0 + 0.2*0.9 + 1.5*0.5) = 0.8 + 0.5 + 0.18 + 0.75 = 1.23
    energy[0, 0, 1, 0] = (0.3*0.1 + 0.7*1.0 + 0.6*0.3 + 0.8*0.5) = 0.03 + 0.7 + 0.18 + 0.4 = 1.31
    energy[0, 0, 1, 1] = (0.3*0.2 + 0.7*0.4 + 0.6*0.6 + 0.8*0.7) = 0.06 + 0.28 + 0.36 + 0.56 = 1.26
    energy[0, 0, 1, 2] = (0.3*0.8 + 0.7*1.0 + 0.6*0.9 + 0.8*0.5) = 0.24 + 0.7 + 0.54 + 0.4 = 1.88
    

总结

torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) 用于计算 querieskeys 张量在注意力机制中的点积,相似性得分。它通过爱因斯坦求和约定指定了如何在多维张量上执行这些操作,使得代码更简洁、效率更高。

Code

AI_With_NumPy
此项目汇集了很多AI相关的代码实现,供大家学习使用,欢迎点赞收藏👏🏻

备注

个人水平有限,有问题随时交流~

这篇关于torch.einsum详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1091516

相关文章

SpringBoot整合mybatisPlus实现批量插入并获取ID详解

《SpringBoot整合mybatisPlus实现批量插入并获取ID详解》这篇文章主要为大家详细介绍了SpringBoot如何整合mybatisPlus实现批量插入并获取ID,文中的示例代码讲解详细... 目录【1】saveBATch(一万条数据总耗时:2478ms)【2】集合方式foreach(一万条数

Python装饰器之类装饰器详解

《Python装饰器之类装饰器详解》本文将详细介绍Python中类装饰器的概念、使用方法以及应用场景,并通过一个综合详细的例子展示如何使用类装饰器,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录1. 引言2. 装饰器的基本概念2.1. 函数装饰器复习2.2 类装饰器的定义和使用3. 类装饰

MySQL 中的 JSON 查询案例详解

《MySQL中的JSON查询案例详解》:本文主要介绍MySQL的JSON查询的相关知识,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录mysql 的 jsON 路径格式基本结构路径组件详解特殊语法元素实际示例简单路径复杂路径简写操作符注意MySQL 的 J

Python ZIP文件操作技巧详解

《PythonZIP文件操作技巧详解》在数据处理和系统开发中,ZIP文件操作是开发者必须掌握的核心技能,Python标准库提供的zipfile模块以简洁的API和跨平台特性,成为处理ZIP文件的首选... 目录一、ZIP文件操作基础三板斧1.1 创建压缩包1.2 解压操作1.3 文件遍历与信息获取二、进阶技

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

Java中的@SneakyThrows注解用法详解

《Java中的@SneakyThrows注解用法详解》:本文主要介绍Java中的@SneakyThrows注解用法的相关资料,Lombok的@SneakyThrows注解简化了Java方法中的异常... 目录前言一、@SneakyThrows 简介1.1 什么是 Lombok?二、@SneakyThrows

Java中字符串转时间与时间转字符串的操作详解

《Java中字符串转时间与时间转字符串的操作详解》Java的java.time包提供了强大的日期和时间处理功能,通过DateTimeFormatter可以轻松地在日期时间对象和字符串之间进行转换,下面... 目录一、字符串转时间(一)使用预定义格式(二)自定义格式二、时间转字符串(一)使用预定义格式(二)自

Redis Pipeline(管道) 详解

《RedisPipeline(管道)详解》Pipeline管道是Redis提供的一种批量执行命令的机制,通过将多个命令一次性发送到服务器并统一接收响应,减少网络往返次数(RTT),显著提升执行效率... 目录Redis Pipeline 详解1. Pipeline 的核心概念2. 工作原理与性能提升3. 核

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Nginx location匹配模式与规则详解

《Nginxlocation匹配模式与规则详解》:本文主要介绍Nginxlocation匹配模式与规则,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、环境二、匹配模式1. 精准模式2. 前缀模式(不继续匹配正则)3. 前缀模式(继续匹配正则)4. 正则模式(大