本文主要是介绍Pytorch入门(番外)点乘与相乘,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
今天看到一行pytorch的代码
import torch
from torch.autograd import Variabletensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor, requires_grad=True)
v_out = torch.mean(variable*variable)
很理所当然的理解为两个矩阵相乘,但是打印输出看的时候觉得不对
tensor([[ 1., 4.],[ 9., 16.]], grad_fn=<MulBackward0>)
这里明显做了一个点乘
那如何才能让这两个variable变量做矩阵的乘法呢
print(torch.mm(variable,variable))
tensor([[ 7., 10.],[15., 22.]], grad_fn=<MmBackward>)
嗯,这样就可以了
这篇关于Pytorch入门(番外)点乘与相乘的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!