本文主要是介绍[pytorch基础操作] 矩阵batch乘法大全(dot,* 和 mm,bmm,@,matmul),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 逐元素相乘
- torch.dot
- *
- 矩阵乘法
- torch.mm
- torch.bmm
- @ 和 torch.matmul
逐元素相乘
逐元素相乘是指对应位置上的元素相乘,要求张量的形状相同
。
torch.dot
按位相乘torch.dot
:计算两个张量的点积(内积),只支持1D张量(向量),不支持broadcast。
import torch# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 计算点积
result = torch.dot(a, b)
print(result) # 输出: tensor(32)
*
*
: 逐元素相乘,适用于任何维度的张量,要求张量的形状相同。
import torch# 创建两个张量
a = torch.randn(2, 3, 4)
b = torch.randn(2, 3, 4)# 逐元素相乘
result = a * b
print(result.shape)
矩阵乘法
矩阵乘法,执行矩阵乘法,前行乘后列,要求第一个矩阵的列数(tensor1.shape[-1])
与第二个矩阵的行数(tensor2.shape[-2])
相等。如shape=(n,r)
乘shape=(r,m)
torch.mm
torch.mm
: 执行两个矩阵的乘法,适用于2D张量(矩阵)(h,w)/(seq_len,dim),不支持broadcast。
import torch# 创建两个矩阵
a = torch.rand(2,3)
b = torch.rand(3,2)# 计算矩阵乘法
result = torch.mm(a, b)
print(result.shape) # [2,2]
torch.bmm
torch.bmm
: 执行两个批次矩阵的乘法,适用于3D张量(b,h,w)/(b,seq_len,dim),不支持broadcast。
import torch# 创建两个批次矩阵
batch1 = torch.randn(10, 3, 4) # 10个3x4的矩阵
batch2 = torch.randn(10, 4, 5) # 10个4x5的矩阵# 计算批次矩阵乘法
result = torch.bmm(batch1, batch2)
print(result.shape) # [10, 3, 5]
@ 和 torch.matmul
@
或 torch.matmul
: 两者完全等价,执行任意维度两个张量的矩阵乘法,支持张量的broadcast广播规则。
import torch# 创建两个张量
a = torch.randn(2, 8, 128, 64)
b = torch.randn(2, 8, 64, 128)# 使用 @ 运算符进行矩阵乘法
result = a @ b
print(result.shape) # [2, 8, 128, 128]# 使用 torch.matmul 进行矩阵乘法
result = torch.matmul(a, b)
print(result.shape) # [2, 8, 128, 128]
这篇关于[pytorch基础操作] 矩阵batch乘法大全(dot,* 和 mm,bmm,@,matmul)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!