torch.einsum 爱因斯坦求和约定

2024-08-26 02:28

本文主要是介绍torch.einsum 爱因斯坦求和约定,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.einsum是一个强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention)。它可以简洁地表达复杂的张量运算。

  1. 对于 l_pos = torch.einsum('nc,nc->n', [q, k])

    • ‘nc,nc->n’ 是一个表示运算规则的字符串。
    • ‘nc’ 表示一个形状为 (N, C) 的张量,N 是批次大小,C 是特征维度。
    • 这个操作等同于矩阵乘法后的对角线元素,或者说是每对向量的点积。

    示例:

    q = torch.tensor([[1, 2], [3, 4]])
    k = torch.tensor([[5, 6], [7, 8]])
    result = torch.einsum('nc,nc->n', [q, k])
    # 等价于 
    # result = torch.sum(q * k, dim=1)
    # 结果: tensor([17, 53])
    
  2. 对于 l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

    • ‘nc,ck->nk’ 表示两个矩阵的乘法。
    • ‘nc’ 是形状为 (N, C) 的查询张量。
    • ‘ck’ 是形状为 (C, K) 的队列张量,K 是队列长度。
    • 结果是一个形状为 (N, K) 的张量。

    示例:

    q = torch.tensor([[1, 2], [3, 4]])
    queue = torch.tensor([[5, 6, 7], [8, 9, 10]])
    result = torch.einsum('nc,ck->nk', [q, queue])
    # 等价于
    # result = torch.matmul(q, queue)
    # 结果: tensor([[21, 24, 27],
    #               [47, 54, 61]])
    

einsum的优势:

  1. 灵活性:可以用简洁的符号表示复杂的张量运算。
  2. 效率:在某些情况下比显式循环更高效。
  3. 可读性:一旦熟悉了符号,代码变得更易读。

这篇关于torch.einsum 爱因斯坦求和约定的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1107288

相关文章

用einsum实现MultiHeadAttention前向传播

einsum教程网站Einstein Summation in Numpy | Olexa Bilaniuk's IFT6266H16 Course Blog 编写训练模型 import tensorflow as tfclass Model(tf.keras.Model):def __init__(self, num_heads, model_dim):super().__init__

1 模拟——67. 二进制求和

1 模拟 67. 二进制求和 给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 示例 1:输入:a = "11", b = "1"输出:"100"示例 2:输入:a = "1010", b = "1011"输出:"10101" 算法设计 可以从低位到高位(从后向前)计算,用一个变量carry记录进位,如果有字符没处理完或者有进位,则循环处理。两个字符串对

Win32函数调用约定(Calling Convention)

平常我们在C#中使用DllImportAttribute引入函数时,不指明函数调用约定(CallingConvention)这个参数,也可以正常调用。如FindWindow函数 [DllImport("user32.dll", EntryPoint="FindWindow", SetLastError = true)]public static extern IntPtr FindWindow

Leetcode67---二进制求和

https://leetcode.cn/problems/add-binary/description/ 给出的两个二进制,我们可以从最后开始往前运算。 给当前短的一位前面补充0即可。 class Solution {public String addBinary(String a, String b) {//给的就是二进制字符串 最后一位开始遍历 如果没有就补充0?StringBuil

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

UVa 10820 Send a Table (Farey数列欧拉函数求和)

这里先说一下欧拉函数的求法 先说一下筛选素数的方法 void Get_Prime(){ /*筛选素数法*/for(int i = 0; i < N; i++) vis[i] = 1;vis[0] = vis[1] = 0;for(int i = 2; i * i < N; i++)if(vis[i]){for(int j = i * i; j < N; j += i)vis[j] =

【hdu】敌兵布阵(线段树,更加结点,区间求和)

最近开始刷线段树,主要围绕notonlysuccess的线段树总结刷。 结点修改还是比较简单的,不需要什么懒惰标记,直接二分递归就可以了。 #include <iostream>#include <cstdlib>#include <cstdio>#include <string>#include <cstring>#include <cmath>#include <vecto

上海市计算机学会竞赛平台2024年7月月赛丙组求和问题

题目描述 给定 nn 个整数 a1,a2,…,ana1​,a2​,…,an​,请问这个序列最长有多少长的前缀,满足元素的和大于或等于 00?如果任何长度大于 00 的前缀之和都为负数,则输出 00 输入格式 第一行:单个整数表示 nn第二行:nn 个整数表示 a1,a2,…,ana1​,a2​,…,an​ 输出格式 单个整数:表示最长的前缀长度,使得前缀的和大于等于 00 数据范围

每日OJ_牛客_求和(递归深搜)

目录 牛客_求和(递归深搜) 解析代码 牛客_求和(递归深搜) 求和_好未来笔试题_牛客网 解析代码         递归中每次累加一个新的数,如果累加和大于等于目标,结束递归。此时如果累加和正好等于目标,则打印组合。向上回退搜索其它组合。此题本身就是一个搜索的过程,找到所有的组合。 #include <iostream>#include <cmath>#in