【张量乘法】pytorch中的mul、dot、mm、matmul

2024-05-28 09:12

本文主要是介绍【张量乘法】pytorch中的mul、dot、mm、matmul,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

张量的乘法是pytorch等神经网络开发框架中最常见、最基本的操作之一。

1,torch.mul

对应位置的元素相乘。mul即表示张量中对应位置元素的相乘,也是最容易理解的乘法。

import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
res = torch.mul(a, b)
print(res)# [[ 5, 12], [21, 32]]

2, torch.dot

表示两个1D向量的点乘:(注意:torch.dot和np.dot用法差异较大
t o r c h . d o t ( [ x 1 , y 1 ] , [ x 2 , y 2 ] ) = x 1 ⋅ x 2 + y 1 ⋅ y 2 (1) torch.dot([x_1,y_1], [x_2,y_2]) =x_1\cdot x_2+ y_1\cdot y_2 \tag{1} torch.dot([x1,y1],[x2,y2])=x1x2+y1y2(1)
两个1D-vector在torch.dot后变成一个标量。实验代码:

res = torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
print(res)
# 7

torch.dot使用有以下要求:

  1. 只针对1D向量;
  2. 向量必须等长;

3,torch.mm

表示矩阵乘法, ( m , n ) × ( n , p ) → ( m , p ) (m,n) \times (n,p) \rightarrow (m, p) (m,n)×(n,p)(m,p)

import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
res = torch.mm(a, b)
print(res)
# [[19, 22], [43, 50]]

4,torch.matmul

也表示矩阵乘,在输入2个1D向量时,表现出与torch.dot一样的效果:

res = torch.matmul(torch.tensor([2, 3]), torch.tensor([2, 1]))
print(res)
# 7

输入2个2D向量时,表达的是矩阵乘法,与torch.mm有一样的效果。

import torch# 1D x 1D
res = torch.matmul(torch.tensor([2, 3]), torch.tensor([2, 1]))
print(res)
# 7# 2D x 2D
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
res = torch.matmul(a, b)
print(res)
# [[19, 22], [43, 50]]# 1D x 2D
a = torch.tensor([1, 2])
b = torch.tensor([[5, 6], [7, 8]])
res = torch.matmul(a, b)
print(res)
# [19, 22]# 2D x 1D
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([5, 6])
res = torch.matmul(a, b)
print(res)
# [17, 39]# (j, 1, n, n) x (k, n, n) -> (j, k, n, n)
# (j, 1, n, m) x (k, m, p) -> (j, k, n, p)

这篇关于【张量乘法】pytorch中的mul、dot、mm、matmul的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1010146

相关文章

hdu 6198 dfs枚举找规律+矩阵乘法

number number number Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Problem Description We define a sequence  F : ⋅   F0=0,F1=1 ; ⋅   Fn=Fn

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

高精度加法,乘法,阶乘

#include <iostream>#include <map>#include <string>#include <algorithm>using namespace std;const int Max = 50000;string str1,str2;/***********乘法***********/void chenfa(){cin >> str1>>str2;int a

Python(TensorFlow和PyTorch)两种显微镜成像重建算法模型(显微镜学)

🎯要点 🎯受激发射损耗显微镜算法模型:🖊恢复嘈杂二维和三维图像 | 🖊模型架构:恢复上下文信息和超分辨率图像 | 🖊使用嘈杂和高信噪比的图像训练模型 | 🖊准备半合成训练集 | 🖊优化沙邦尼尔损失和边缘损失 | 🖊使用峰值信噪比、归一化均方误差和多尺度结构相似性指数量化结果 | 🎯训练荧光显微镜模型和对抗网络图形转换模型 🍪语言内容分比 🍇Python图像归一化

Pytorch环境搭建时的各种问题

1 问题 1.一直soving environment,跳不出去。网络解决方案有:配置清华源,更新conda等,没起作用。2.下载完后,有3个要done的东西,最后那个exe开头的(可能吧),总是报错。网络解决方案有:用管理员权限打开prompt等,没起作用。3.有时候配置完源,安装包的时候显示什么https之类的东西,去c盘的用户那个文件夹里找到".condarc"文件把里面的网址都改成htt

【DL--03】深度学习基本概念—张量

张量 TensorFlow中的中心数据单位是张量。张量由一组成形为任意数量的数组的原始值组成。张量的等级是其维数。以下是张量的一些例子: 3 # a rank 0 tensor; this is a scalar with shape [][1. ,2., 3.] # a rank 1 tensor; this is a vector with shape [3][[1., 2., 3.]

【hive 日期转换】Hive中yyyymmdd和yyyy-mm-dd日期之间的切换

方法1: from_unixtime+ unix_timestamp--20171205转成2017-12-05 select from_unixtime(unix_timestamp('20171205','yyyymmdd'),'yyyy-mm-dd') from dual;--2017-12-05转成20171205select from_unixtime(unix_timestamp