python中pytorch的广播机制——Broadcasting

2023-10-11 02:36

本文主要是介绍python中pytorch的广播机制——Broadcasting,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

广播机制

numpy 在算术运算期间采用“广播”来处理具有不同形状的 array ,即将较小的阵列在较大的阵列上“广播”,以便它们具有兼容的形状。Broadcasting是一种没有copy数据的expand

  • 不过两个维度不相同,在前面插入维度1
  • 扩张维度1到相同的维度

例如:Feature maps:[4,32,14,14]
Bias:[32,1,1]=>[1,32,1,1]=>[4,32,14,14]

A:[32,1,1]=>[1,32,1,1]=>[4,32,14,14]
B:[4,32,14,14]
这里就可以进行相同维度的相加

image


比如说一个[4,1]+[1,2]
那么这个[4,1]可以再复制列变为[4,2]
[1,2]可以再复制4行变为[4,2]

首先用1将那个小的维度的tensor扩展成大的维度相同的维度,然后将1扩张成两者的相同维度,如果有两个维度不相同,并且都不是1的话,则不能broadcasting

 

广播规则

当对两个 array 进行操作时,numpy 会逐元素比较它们的形状。从尾(即最右边)维度开始,然后向左逐渐比较。只有当两个维度 1)相等 or 2)其中一个维度是1 时,这两个维度才会被认为是兼容。

如果不满足这些条件,则会抛出 ValueError:operands could not be broadcast together 异常,表明 array 的形状不兼容。最终结果 array 的每个维度尽可能不为 1 ,是两个操作数各个维度中较大的值 。

例如,有一个 256x256x3 的 RGB 值图片 array ,需要将图像中的每种颜色缩放不同的值,此时可以将图像乘以具有 3 个值的一维 array 。根据广播规则排列这两个 array 的尾维度大小,是兼容的:

 图片(3d array): 256 x 256 x 3
缩放(1d array):             3
结果(3d array): 256 x 256 x 3

当比较的任一维度是 1 时,使用另一个,也就是说,大小为 1 的维度被拉伸或“复制”以匹配另一个维度。
在以下示例中,A 和 B 数组都有长度为 1 的维度,在广播操作期间扩展为更大的大小:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
result (4d array):  8 x 7 x 6 x 5

以二维为例,更加方便的解释“广播”:
已知 a.shape 是(5,1),b.shape 是(1,6),c.shape 是(6,),d.shape 是(), d 是一个标量, a, b, c,和 d 都可以“广播”到维度 (5,6);

a “广播”为一个 (5,6) array ,其中 a[:,0] 被“广播”到其他列,
b “广播”为一个 (5,6) array ,其中 b[0,:] 被广播到其他行,
c 类似于 (1,6) array ,其中 c[:] 广播到每一行,
d 是标量,“广播”为 (5,6) array ,其中每个元素都一样,重复d值。
 

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # 倒数第二个维度不兼容
>>> a = np.array([[ 0.0,  0.0,  0.0],
...               [10.0, 10.0, 10.0],
...               [20.0, 20.0, 20.0],
...               [30.0, 30.0, 30.0]])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a + b
array([[  1.,   2.,   3.],[11.,  12.,  13.],[21.,  22.,  23.],[31.,  32.,  33.]])
>>> b = np.array([1.0, 2.0, 3.0, 4.0])
>>> a + b
Traceback (most recent call last):
ValueError: operands could not be broadcast together with shapes (4,3) (4,)

 

 

在某些情况下,广播会拉伸两个 array 以形成一个大于任何一个初始 array 的结果 array 。 

>>> a = np.array([0.0, 10.0, 20.0, 30.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a[:, np.newaxis] + b
array([[ 1.,   2.,   3.],[11.,  12.,  13.],[21.,  22.,  23.],[31.,  32.,  33.]])

 

newaxis 运算符将新轴插入到 a 中,使其成为二维 4x1 array 。将 4x1 array 与形状为 (3,) 的 b 组合,产生一个 4x3 array 。 

这里注意要都从右端进行匹配:
A:[                     ]
B:          [           ]
就是这样补充
我们看个例子吧:

a=torch.randn(2,3,4)
b=torch.randn(2,3)
a+b
#The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

image


但是这样是可以的

image


也就是(2,3,4)+(2,3)是不可以的,(2,3,4)+(3,4)是可以的,因为他们是右看齐的。

Situation 1:
▪ [4, 32, 14, 14]
▪ [1, 32, 1, 1] => [4, 32, 14, 14]

Situation 2
▪ [4, 32, 14, 14]
▪ [14, 14] => [1, 1, 14, 14] => [4, 32, 14, 14]

Situation 3
▪ [4, 32, 14, 14]
▪ [2, 32, 14, 14]
▪ Dim 0 has dim, can NOT insert and expand to same
▪ Dim 0 has distinct dim, NOT size 1
▪ NOT broadcasting-able

Situation 4
▪ [4, 32, 14, 14]
▪ [4, 32, 14]
这样是不行的,因为我们要右看齐,match from
last dim

Situation 5
▪ [4, 3, 32, 32]
▪ + [32, 32]
▪ + [3, 1, 1]
▪ + [1, 1, 1, 1]
这都是可以的

这篇关于python中pytorch的广播机制——Broadcasting的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/184911

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

【Python编程】Linux创建虚拟环境并配置与notebook相连接

1.创建 使用 venv 创建虚拟环境。例如,在当前目录下创建一个名为 myenv 的虚拟环境: python3 -m venv myenv 2.激活 激活虚拟环境使其成为当前终端会话的活动环境。运行: source myenv/bin/activate 3.与notebook连接 在虚拟环境中,使用 pip 安装 Jupyter 和 ipykernel: pip instal

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

nudepy,一个有趣的 Python 库!

更多资料获取 📚 个人网站:ipengtao.com 大家好,今天为大家分享一个有趣的 Python 库 - nudepy。 Github地址:https://github.com/hhatto/nude.py 在图像处理和计算机视觉应用中,检测图像中的不适当内容(例如裸露图像)是一个重要的任务。nudepy 是一个基于 Python 的库,专门用于检测图像中的不适当内容。该

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,