This article introduces basic tensor operations in PyTorch: add/minus/multiply/divide, matrix multiplication with mm and matmul, pow/sqrt/rsqrt, exp/log, approximation, and statistical properties such as mean, sum, min, and max. Hopefully it serves as a useful reference for developers.
1.12. Basic operations
1.12.1. add/minus/multiply/divide
1.12.2. Matrix multiplication: mm, matmul
1.12.3. pow/sqrt/rsqrt
1.12.4. exp/log
1.12.5. Approximation: floor, ceil, trunc, frac, round
1.12.6. Clamping: max, min, median, clamp
1.13. Statistical properties
1.13.1. norm
1.13.2. mean, sum, min, max, prod
1.13.3. max, argmin, argmax, topk, kthvalue
1.13.4. compare
1.12. Basic operations
1.12.1. add/minus/multiply/divide
a + b = torch.add(a, b)
a - b = torch.sub(a, b)
a * b = torch.mul(a, b)
a / b = torch.div(a, b)
# -*- coding: UTF-8 -*-
import torch

a = torch.rand(3, 4)
b = torch.rand(4)
print(a)
"""
Output:
tensor([[0.8796, 0.9511, 0.1630, 0.0036],
        [0.4834, 0.2088, 0.3118, 0.7274],
        [0.8440, 0.3282, 0.4091, 0.4249]])
"""
print(b)
"""
Output:
tensor([0.2553, 0.5917, 0.7143, 0.7302])
"""
# Addition; b is broadcast to a's shape
print(a + b)
"""
Output:
tensor([[1.1349, 1.5428, 0.8773, 0.7338],
        [0.7387, 0.8005, 1.0261, 1.4577],
        [1.0993, 0.9199, 1.1235, 1.1552]])
"""
# Equivalent to the addition above
print(torch.add(a, b))
"""
Output:
tensor([[1.1349, 1.5428, 0.8773, 0.7338],
        [0.7387, 0.8005, 1.0261, 1.4577],
        [1.0993, 0.9199, 1.1235, 1.1552]])
"""
# Check that the two results are equal;
# torch.all returns True only if every element is True
print(torch.all(torch.eq(a + b, torch.add(a, b))))
"""
Output:
tensor(True)
"""
1.12.2. Matrix multiplication: mm, matmul
torch.mm(a, b)     # only works for 2-D tensors
torch.matmul(a, b)
a @ b = torch.matmul(a, b)  # the recommended form
Uses:
Dimensionality reduction: e.g. [4, 784] @ [784, 512] = [4, 512]
Multiplying tensors with more than 2 dimensions: only the last two dimensions are matrix-multiplied: [4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]
Precondition: besides the last two dimensions being compatible for matrix multiplication, the leading dimensions must satisfy the broadcasting rules; here, for example, the leading dimensions could be [4, 3] and [4, 1].
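The batched case above, including broadcasting of a leading dimension of size 1, can be sketched as:

```python
import torch

a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 1, 64, 32)   # dim 1 broadcasts from 1 to 3

# matmul multiplies the last two dims, broadcasting the rest;
# torch.mm would raise an error here because the inputs are not 2-D
c = torch.matmul(a, b)         # same as a @ b
print(c.shape)                 # torch.Size([4, 3, 28, 32])
```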
# -*- coding: UTF-8 -*-
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 2)
print(a)
"""
Output:
tensor([[ 0.0935, -0.1704,  1.1908],
        [-0.2091,  0.0285, -0.5522]])
"""
print(b)
"""
Output:
tensor([[ 0.1315, -0.4669],
        [-0.1053,  0.9560],
        [ 0.0769,  0.9642]])
"""
print(torch.mm(a, b))
"""
Output:
tensor([[ 0.1218,  0.9416],
        [-0.0730, -0.4076]])
"""
print(torch.matmul(a, b))
"""
Output:
tensor([[ 0.1218,  0.9416],
        [-0.0730, -0.4076]])
"""
1.12.3. pow/sqrt/rsqrt
# -*- coding: UTF-8 -*-
import torch

a = torch.full([2, 2], 3)
print(a)
"""
Output:
tensor([[3, 3],
        [3, 3]])
"""
print(a.pow(2))
"""
Output:
tensor([[9, 9],
        [9, 9]])
"""
aa = a ** 2   # same as a.pow(2)
print(aa)
"""
Output:
tensor([[9, 9],
        [9, 9]])
"""
# Square root, operator form
print(aa ** 0.5)
"""
Output:
tensor([[3., 3.],
        [3., 3.]])
"""
# Square root, function form
print(aa.pow(0.5))
"""
Output:
tensor([[3., 3.],
        [3., 3.]])
"""
1.12.4. exp/log
# -*- coding: UTF-8 -*-
import torch

a = torch.ones(2, 2)
print(a)
"""
Output:
tensor([[1., 1.],
        [1., 1.]])
"""
# Natural base e
print(torch.exp(a))
"""
Output:
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])
"""
# Logarithm
# The default base is e;
# use torch.log2 / torch.log10 for base 2 / base 10
print(torch.log(a))
"""
Output:
tensor([[0., 0.],
        [0., 0.]])
"""
1.12.5. Approximation: floor, ceil, trunc, frac, round
a.floor()  # round down
a.ceil()   # round up
a.trunc()  # keep the integer part (truncate)
a.frac()   # keep the fractional part (fraction)
a.round()  # round to the nearest integer
# -*- coding: UTF-8 -*-
import torch

a = torch.tensor(3.14)
print(a.floor(), a.ceil(), a.trunc(), a.frac())
"""
Output:
tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
"""
a = torch.tensor(3.499)
print(a.round())
"""
Output:
tensor(3.)
"""
a = torch.tensor(3.5)
print(a.round())
"""
Output:
tensor(4.)
"""
1.12.6. Clamping: max, min, median, clamp
a.max()    # maximum
a.min()    # minimum
a.median() # median
a.clamp(10)     # clamp the minimum to 10
a.clamp(0, 10)  # restrict values to [0, 10], both ends inclusive
# -*- coding: UTF-8 -*-
import torch

grad = torch.rand(2, 3) * 15
print(grad)
"""
Output:
tensor([[ 4.7390,  7.9376, 12.8128],
        [12.1366,  1.4925,  9.5263]])
"""
print(grad.max())
"""
Output:
tensor(12.8128)
"""
print(grad.median())
"""
Output:
tensor(7.9376)
"""
print(grad.clamp(10))
"""
Output:
tensor([[10.0000, 10.0000, 12.8128],
        [12.1366, 10.0000, 10.0000]])
"""
print(grad.clamp(0, 10))
"""
Output:
tensor([[ 4.7390,  7.9376, 10.0000],
        [10.0000,  1.4925,  9.5263]])
"""
1.13. Statistical properties
1.13.1. norm
1. norm is different from normalize and batch_norm: norm computes a vector or matrix norm, while normalize and batch_norm perform normalization.
2. A matrix norm and a vector norm are also different things.
a = torch.full([8], 1.0)   # use a float tensor; norm is defined for floating-point types
b = a.view(2, 4)
c = a.view(2, 2, 2)
a.norm(1)   # L1 norm of a
: tensor(8.)
b.norm(1)
: tensor(8.)
c.norm(1)
: tensor(8.)
b.norm(2)   # L2 norm of b
: tensor(2.8284)
b.norm(1, dim=1)   # L1 norm along dim=1
: tensor([4., 4.])
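The same idea works along any dimension; a runnable sketch of the per-dimension norms above:

```python
import torch

b = torch.full([2, 4], 1.0)

# L1 norm along dim=1: sum of absolute values of each row -> tensor([4., 4.])
print(b.norm(1, dim=1))
# L1 norm along dim=0: sum over each column -> tensor([2., 2., 2., 2.])
print(b.norm(1, dim=0))
# L2 norm of the whole (flattened) tensor: sqrt(8) ~= 2.8284
print(b.norm(2))
```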
1.13.2. mean, sum, min, max, prod
For argmin and argmax: if no dimension is given, the tensor is flattened first and the index of the minimum/maximum element of the flattened tensor is returned.
# -*- coding: UTF-8 -*-
import torch

a = torch.arange(8).view(2, 4).float()
print(a)
"""
Output:
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
"""
print(a.min(), a.max(), a.mean(), a.prod(), a.sum(), a.argmin(), a.argmax())
"""
Output:
tensor(0.) tensor(7.) tensor(3.5000) tensor(0.) tensor(28.) tensor(0) tensor(7)
"""
1.13.3. max, argmin, argmax, topk, kthvalue
# -*- coding: UTF-8 -*-
import torch

a = torch.rand(4, 10)
print("----------------1---------------------")
print(a)
"""
Output:
tensor([[0.4550, 0.0754, 0.5295, 0.2976, 0.7861, 0.5620, 0.2705, 0.0929, 0.3207, 0.3191],
        [0.8027, 0.3193, 0.8842, 0.2734, 0.3881, 0.8242, 0.6090, 0.4655, 0.0993, 0.6304],
        [0.7399, 0.4701, 0.7231, 0.4278, 0.6317, 0.6905, 0.9834, 0.5210, 0.7772, 0.7630],
        [0.3206, 0.5491, 0.6806, 0.5545, 0.3620, 0.3515, 0.2682, 0.5013, 0.1984, 0.3038]])
"""
print("----------------2---------------------")
print(a.max(dim=1))
"""
Output:
torch.return_types.max(
values=tensor([0.7861, 0.8842, 0.9834, 0.6806]),
indices=tensor([4, 2, 6, 2]))
"""
print("----------------3---------------------")
print(a.argmax(dim=1))
"""
Output:
tensor([4, 2, 6, 2])
"""
print("-----------------4--------------------")
print(a.max(dim=1, keepdim=True))  # keep the result's number of dimensions the same as a's
"""
Output:
torch.return_types.max(
values=tensor([[0.7861],
        [0.8842],
        [0.9834],
        [0.6806]]),
indices=tensor([[4],
        [2],
        [6],
        [2]]))
"""
print("-----------------5--------------------")
print(a.argmax(dim=1, keepdim=True))  # indices of the max values, keeping the dimension
"""
Output:
tensor([[4],
        [2],
        [6],
        [2]])
"""
print("----------------6---------------------")
print(a.topk(3, dim=1))  # the 3 largest values along dim=1, with their indices
"""
Output:
torch.return_types.topk(
values=tensor([[0.7861, 0.5620, 0.5295],
        [0.8842, 0.8242, 0.8027],
        [0.9834, 0.7772, 0.7630],
        [0.6806, 0.5545, 0.5491]]),
indices=tensor([[4, 5, 2],
        [2, 5, 0],
        [6, 8, 9],
        [2, 3, 1]]))
"""
print("----------------7---------------------")
print(a.topk(3, dim=1, largest=False))  # the 3 smallest values along dim=1, with their indices
"""
Output:
torch.return_types.topk(
values=tensor([[0.0754, 0.0929, 0.2705],
        [0.0993, 0.2734, 0.3193],
        [0.4278, 0.4701, 0.5210],
        [0.1984, 0.2682, 0.3038]]),
indices=tensor([[1, 7, 6],
        [8, 3, 1],
        [3, 1, 7],
        [8, 6, 9]]))
"""
print("----------------8---------------------")
print(a.kthvalue(8, dim=1))  # the 8th *smallest* value along dim=1 (here the 3rd largest of 10), with its index
"""
Output:
torch.return_types.kthvalue(
values=tensor([0.5295, 0.8027, 0.7630, 0.5491]),
indices=tensor([2, 0, 9, 1]))
"""
1.13.4. compare
>, >=, <, <=, !=, ==
torch.eq(a, b)  # element-wise comparison of two tensors
# -*- coding: UTF-8 -*-
import torch

a = torch.rand(4, 10)
print("----------------1---------------------")
print(a)
"""
Output:
tensor([[0.5431, 0.2628, 0.1717, 0.2056, 0.0288, 0.9366, 0.3158, 0.7862, 0.0668, 0.9356],
        [0.8946, 0.1231, 0.8310, 0.3631, 0.1795, 0.9628, 0.9884, 0.2004, 0.1994, 0.6071],
        [0.8863, 0.5992, 0.7863, 0.1543, 0.3057, 0.0189, 0.0196, 0.0419, 0.1391, 0.2097],
        [0.7474, 0.1389, 0.6977, 0.7851, 0.6969, 0.6046, 0.8341, 0.5421, 0.3144, 0.8706]])
"""
# Element-wise comparison against a scalar
print(a > 0.5)
"""
Output:
tensor([[ True, False, False, False, False,  True, False,  True, False,  True],
        [ True, False,  True, False, False,  True,  True, False, False,  True],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True, False,  True,  True,  True,  True,  True,  True, False,  True]])
"""
print(torch.gt(a, 0.5))  # same as a > 0.5
"""
Output:
tensor([[ True, False, False, False, False,  True, False,  True, False,  True],
        [ True, False,  True, False, False,  True,  True, False, False,  True],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True, False,  True,  True,  True,  True,  True,  True, False,  True]])
"""
print(a != 0)
"""
Output:
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
"""
a = torch.ones(2, 3)
b = torch.randn(2, 3)
print(torch.eq(a, b))
"""
Output:
tensor([[False, False, False],
        [False, False, False]])
"""
print(torch.eq(a, a))
"""
Output:
tensor([[True, True, True],
        [True, True, True]])
"""
print(torch.equal(a, a))  # unlike eq, equal returns a single Python bool: same shape and same values
"""
Output:
True
"""