13,12_基本运算,add/minus/multiply/divide,矩阵相乘mm,matmul,pow/sqrt/rsqrt,exp/log近似值,统计属性,mean,sum,min,max

本文主要是介绍13,12_基本运算,add/minus/multiply/divide,矩阵相乘mm,matmul,pow/sqrt/rsqrt,exp/log近似值,统计属性,mean,sum,min,max,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.12.基本运算
1.12.1.add/minus/multiply/divide
1.12.2.矩阵相乘mm,matmul
1.12.3.pow/sqrt/rsqrt
1.12.4.exp/log
1.12.5.近似值floor、ceil、trunc、frac、round
1.12.6.现幅max、min、median、clamp
1.13.统计属性
1.13.1.norm
1.13.2.mean,sum,min,max,prod
1.13.3.max,argmin,argmax,topk,kthvalue
1.13.4.compare

1.12.基本运算

1.12.1.add/minus/multiply/divide

a + b = torch.add(a, b)
a - b = torch.sub(a, b)
a * b = torch.mul(a, b)
a / b = torch.div(a, b)

# -*- coding: UTF-8 -*-import torcha = torch.rand(3, 4)
b = torch.rand(4)
print(a)
"""
输出结果:
tensor([[0.8796, 0.9511, 0.1630, 0.0036],[0.4834, 0.2088, 0.3118, 0.7274],[0.8440, 0.3282, 0.4091, 0.4249]])
"""print(b)
"""
输出结果:
tensor([0.2553, 0.5917, 0.7143, 0.7302])
"""# 相加
# b会被广播
print(a + b)
"""
输出结果:
tensor([[1.1349, 1.5428, 0.8773, 0.7338],[0.7387, 0.8005, 1.0261, 1.4577],[1.0993, 0.9199, 1.1235, 1.1552]])
"""# 等价于上面相加
print(torch.add(a, b))
"""
输出结果:
tensor([[1.1349, 1.5428, 0.8773, 0.7338],[0.7387, 0.8005, 1.0261, 1.4577],[1.0993, 0.9199, 1.1235, 1.1552]])
"""# 比较两个是否相等
print(torch.all(torch.eq(a + b, torch.add(a, b))))
"""
输出结果:
torch.all判断是否所有元素都相等输出结果:
tensor(True)
"""

1.12.2.矩阵相乘mm,matmul

torch.mm(a, b) # 此方法只适用于2维
torch.matmul(a, b)
a @ b = torch.matmul(a, b) # 推荐使用此方法
用处:
降维:比如,[4, 784] @ [784, 512] = [4, 512]
大于2d的数据相乘:最后2个维度的数据相乘:[4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]
前提是:除了最后两个维度满足相乘条件以外,其他维度要满足广播条件,比如此处的前面两个维度只能是[4, 3]和[4, 1]

# -*- coding: UTF-8 -*-import torcha = torch.randn(2, 3)
b = torch.randn(3, 2)
# 输出结果
print(a)
"""
输出结果:
tensor([[ 0.0935, -0.1704,  1.1908],[-0.2091,  0.0285, -0.5522]])
"""
print(b)
"""
输出结果:
tensor([[ 0.1315, -0.4669],[-0.1053,  0.9560],[ 0.0769,  0.9642]])
"""print(torch.mm(a, b))
"""
输出结果:
tensor([[ 0.1218,  0.9416],[-0.0730, -0.4076]])
"""
print(torch.matmul(a, b))
"""
输出结果:
tensor([[ 0.1218,  0.9416],[-0.0730, -0.4076]])
"""

1.12.3.pow/sqrt/rsqrt

# -*- coding: UTF-8 -*-import torcha = torch.full([2, 2], 3)
print(a)
"""
输出结果:
tensor([[3, 3],[3, 3]])
"""print(a.pow(2))
"""
输出结果:
tensor([[9, 9],[9, 9]])
"""aa = a ** 2
print(aa)
"""
输出结果:
tensor([[9, 9],[9, 9]])
"""## 平方根
print(aa ** (0.5))
"""
输出结果:
tensor([[3., 3.],[3., 3.]])
"""# 平方根
print(aa.pow(0.5))
"""
输出结果:
tensor([[3., 3.],[3., 3.]])
"""

1.12.4.exp/log

# -*- coding: UTF-8 -*-import torcha = torch.ones(2, 2)
print(a)
"""
输出结果:
tensor([[1., 1.],[1., 1.]])
"""# 自认底数e
print(torch.exp(a))
"""
输出结果:
tensor([[2.7183, 2.7183],[2.7183, 2.7183]])
"""# 对数
# 默认底数是e
# 可以更换为Log2、Log10
print(torch.log(a))
"""
输出结果:
tensor([[0., 0.],[0., 0.]])
"""

1.12.5.近似值floor、ceil、trunc、frac、round

a.floor() #向下取整
a.ceil() #向上取整
a.trunc() #保留整数部分:truncate,截取
a.frac() #保留小数部分:fraction, 小数
a.round() #四舍五入:round,大约

# -*- coding: UTF-8 -*-import torcha = torch.tensor(3.14)
print(a.floor(),  a.ceil(),  a.trunc(),  a.frac())
"""
输出结果:
tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
"""a = torch.tensor(3.499)
print(a.round())
"""
输出结果:
tensor(3.)
"""a = torch.tensor(3.5)
print(a.round())
"""
输出结果:
tensor(4.)
"""

1.12.6.现幅max、min、median、clamp

a.max() # 最大值
a.min() # 最小值
a.median() # 中位数
a.clamp(10) #最小值限定为10 a.clamp(0, 10) #将数据限定在[0, 10],两边都是闭区间

# -*- coding: UTF-8 -*-import torchgrad = torch.rand(2, 3) * 15
print(grad)
"""
输出结果:
tensor([[ 4.7390,  7.9376, 12.8128],[12.1366,  1.4925,  9.5263]])
"""print(grad.max())
"""
输出结果:
tensor(12.8128)
"""print(grad.median())
"""
输出结果:
tensor(7.9376)
"""print(grad.clamp(10))
"""
输出结果:
tensor([[10.0000, 10.0000, 12.8128],[12.1366, 10.0000, 10.0000]])
"""print(grad.clamp(0, 10))
"""
输出结果:
tensor([[ 4.7390,  7.9376, 10.0000],[10.0000,  1.4925,  9.5263]])
"""

1.13.统计属性

1.13.1.norm

1.norm vs normalize and batch_norm是有区别的:norm是范数的意思,normalize、batch_norm是归一化
2.matrix norm和vector norm是有区别的

a = torch.full([8],1)
b = a.view(2,4)
c = a.view(2,2,2)
a.norm(1) # a tensor 的一范式
: tensor(8.)
b.norm(1)
: tensor(8.)
c.norm(1)
: tensor(8.)b.norm(2) # b tensor 的二范式 
: tensor(2.8284)
b.norm(1,dim=1)
:tensor(4.,4.)

1.13.2.mean,sum,min,max,prod

对于argmin,argmax:如果不给出固定的dimension,会把tensor打平成dim=1,然后返回最小、最大的索引。

# -*- coding: UTF-8 -*-import torcha = torch.arange(8).view(2, 4).float()
print(a)
"""
输出结果:
tensor([[0., 1., 2., 3.],[4., 5., 6., 7.]])
"""print(a.min(), a.max(), a.mean(), a.prod(), a.sum(), a.argmin(), a.argmax())
"""
输出结果:
tensor(0.) tensor(7.) tensor(3.5000) tensor(0.) tensor(28.) tensor(0) tensor(7)
"""

1.13.3.max,argmin,argmax,topk,kthvalue

# -*- coding: UTF-8 -*-import torcha = torch.rand(4, 10)
print("----------------1---------------------")
print(a)
"""
输出结果:
tensor([[0.4550, 0.0754, 0.5295, 0.2976, 0.7861, 0.5620, 0.2705, 0.0929, 0.3207, 0.3191],[0.8027, 0.3193, 0.8842, 0.2734, 0.3881, 0.8242, 0.6090, 0.4655, 0.0993, 0.6304],[0.7399, 0.4701, 0.7231, 0.4278, 0.6317, 0.6905, 0.9834, 0.5210, 0.7772, 0.7630],[0.3206, 0.5491, 0.6806, 0.5545, 0.3620, 0.3515, 0.2682, 0.5013, 0.1984, 0.3038]])
"""print("----------------2---------------------")
print(a.max(dim=1))
"""
输出结果:
torch.return_types.max(
values=tensor([0.7861, 0.8842, 0.9834, 0.6806]),
indices=tensor([4, 2, 6, 2]))
"""print("----------------3---------------------")
print(a.argmax(dim=1))
"""
输出结果:
tensor([4, 2, 6, 2])
"""print("-----------------4--------------------")
print(a.max(dim=1, keepdim=True))          # 希望结果的维度(dim)和a保持一致
"""
输出结果:
torch.return_types.max(
values=tensor([[0.7861],[0.8842],[0.9834],[0.6806]]),
indices=tensor([[4],[2],[6],[2]]))
"""print("-----------------5--------------------")
print(a.argmax(dim=1, keepdim=True))             # 最大的值的维度,保持维度
"""
输出结果:
tensor([[4],[2],[6],[2]])
"""print("----------------6---------------------")
print(a.topk(3, dim=1))           # 每一维中的前3最大值,以及索引
"""
输出结果:
torch.return_types.topk(
values=tensor([[0.7861, 0.5620, 0.5295],[0.8842, 0.8242, 0.8027],[0.9834, 0.7772, 0.7630],[0.6806, 0.5545, 0.5491]]),
indices=tensor([[4, 5, 2],[2, 5, 0],[6, 8, 9],[2, 3, 1]]))
"""
print("----------------7---------------------")
print(a.topk(3, dim=1, largest=False))           # 每维度中,最小的3个值,以及索引
"""
输出结果:
torch.return_types.topk(
values=tensor([[0.0754, 0.0929, 0.2705],[0.0993, 0.2734, 0.3193],[0.4278, 0.4701, 0.5210],[0.1984, 0.2682, 0.3038]]),
indices=tensor([[1, 7, 6],[8, 3, 1],[3, 1, 7],[8, 6, 9]]))
"""print("----------------8---------------------")
print(a.kthvalue(8, dim=1))              # 每个维度中,第8个最大值的值,以及索引
"""
torch.return_types.kthvalue(
values=tensor([0.5295, 0.8027, 0.7630, 0.5491]),
indices=tensor([2, 0, 9, 1]))
"""

1.13.4.compare

, >=, <, <=, !=, ==

torch.eq(a, b) : 比较两个矩阵是否相等

# -*- coding: UTF-8 -*-import torcha = torch.rand(4, 10)
print("----------------1---------------------")
print(a)
"""
输出结果:
tensor([[0.5431, 0.2628, 0.1717, 0.2056, 0.0288, 0.9366, 0.3158, 0.7862, 0.0668, 0.9356],[0.8946, 0.1231, 0.8310, 0.3631, 0.1795, 0.9628, 0.9884, 0.2004, 0.1994, 0.6071],[0.8863, 0.5992, 0.7863, 0.1543, 0.3057, 0.0189, 0.0196, 0.0419, 0.1391, 0.2097],[0.7474, 0.1389, 0.6977, 0.7851, 0.6969, 0.6046, 0.8341, 0.5421, 0.3144, 0.8706]])
"""# 比较每个维度是否大于指定的值
print(a > 0.5)
"""
输出结果:
tensor([[ True, False, False, False, False,  True, False,  True, False,  True],[ True, False,  True, False, False,  True,  True, False, False,  True],[ True,  True,  True, False, False, False, False, False, False, False],[ True, False,  True,  True,  True,  True,  True,  True, False,  True]])
"""print(torch.gt(a, 0.5))
"""
输出结果:
tensor([[ True, False, False, False, False,  True, False,  True, False,  True],[ True, False,  True, False, False,  True,  True, False, False,  True],[ True,  True,  True, False, False, False, False, False, False, False],[ True, False,  True,  True,  True,  True,  True,  True, False,  True]])
"""print(a != 0)
"""
输出结果:
tensor([[True, True, True, True, True, True, True, True, True, True],[True, True, True, True, True, True, True, True, True, True],[True, True, True, True, True, True, True, True, True, True],[True, True, True, True, True, True, True, True, True, True]])
"""a = torch.ones(2, 3)
b = torch.randn(2, 3)
print(torch.eq(a, b))
"""
tensor([[False, False, False],[False, False, False]])
"""print(torch.eq(a, a))
"""
输出结果:
tensor([[True, True, True],[True, True, True]])
"""print(torch.equal(a, a))
"""
输出结果:
True
"""

这篇关于13,12_基本运算,add/minus/multiply/divide,矩阵相乘mm,matmul,pow/sqrt/rsqrt,exp/log近似值,统计属性,mean,sum,min,max的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/918153

相关文章

Java进阶13讲__第12讲_1/2

多线程、线程池 1.  线程概念 1.1  什么是线程 1.2  线程的好处 2.   创建线程的三种方式 注意事项 2.1  继承Thread类 2.1.1 认识  2.1.2  编码实现  package cn.hdc.oop10.Thread;import org.slf4j.Logger;import org.slf4j.LoggerFactory

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

内核启动时减少log的方式

内核引导选项 内核引导选项大体上可以分为两类:一类与设备无关、另一类与设备有关。与设备有关的引导选项多如牛毛,需要你自己阅读内核中的相应驱动程序源码以获取其能够接受的引导选项。比如,如果你想知道可以向 AHA1542 SCSI 驱动程序传递哪些引导选项,那么就查看 drivers/scsi/aha1542.c 文件,一般在前面 100 行注释里就可以找到所接受的引导选项说明。大多数选项是通过"_

uva 575 Skew Binary(位运算)

求第一个以(2^(k+1)-1)为进制的数。 数据不大,可以直接搞。 代码: #include <stdio.h>#include <string.h>const int maxn = 100 + 5;int main(){char num[maxn];while (scanf("%s", num) == 1){if (num[0] == '0')break;int len =

hdu 4565 推倒公式+矩阵快速幂

题意 求下式的值: Sn=⌈ (a+b√)n⌉%m S_n = \lceil\ (a + \sqrt{b}) ^ n \rceil\% m 其中: 0<a,m<215 0< a, m < 2^{15} 0<b,n<231 0 < b, n < 2^{31} (a−1)2<b<a2 (a-1)^2< b < a^2 解析 令: An=(a+b√)n A_n = (a +

最大流=最小割=最小点权覆盖集=sum-最大点权独立集

二分图最小点覆盖和最大独立集都可以转化为最大匹配求解。 在这个基础上,把每个点赋予一个非负的权值,这两个问题就转化为:二分图最小点权覆盖和二分图最大点权独立集。   二分图最小点权覆盖     从x或者y集合中选取一些点,使这些点覆盖所有的边,并且选出来的点的权值尽可能小。 建模:     原二分图中的边(u,v)替换为容量为INF的有向边(u,v),设立源点s和汇点t

hdu 6198 dfs枚举找规律+矩阵乘法

number number number Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Problem Description We define a sequence  F : ⋅   F0=0,F1=1 ; ⋅   Fn=Fn

ImportError: cannot import name ‘print_log‘ from ‘logging‘

mmcv升级到2.+后删除了很多 解决 查FAQ文档,找到 添加到mmcv.utils下即可

DAY16:什么是慢查询,导致的原因,优化方法 | undo log、redo log、binlog的用处 | MySQL有哪些锁

目录 什么是慢查询,导致的原因,优化方法 undo log、redo log、binlog的用处  MySQL有哪些锁   什么是慢查询,导致的原因,优化方法 数据库查询的执行时间超过指定的超时时间时,就被称为慢查询。 导致的原因: 查询语句比较复杂:查询涉及多个表,包含复杂的连接和子查询,可能导致执行时间较长。查询数据量大:当查询的数据量庞大时,即使查询本身并不复杂,也可能导致

【Java中的位运算和逻辑运算详解及其区别】

Java中的位运算和逻辑运算详解及其区别 在 Java 编程中,位运算和逻辑运算是常见的两种操作类型。位运算用于操作整数的二进制位,而逻辑运算则是处理布尔值 (boolean) 的运算。本文将详细讲解这两种运算及其主要区别,并给出相应示例。 应用场景了解 位运算和逻辑运算的设计初衷源自计算机底层硬件和逻辑运算的需求,它们分别针对不同的处理对象和场景。以下是它们设计的初始目的简介: