pytorch中backward函数的参数gradient作用的数学过程

2023-10-28 15:20

本文主要是介绍pytorch中backward函数的参数gradient作用的数学过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch中backward函数的参数gradient作用的数学过程

    • 问题描述
    • 实例分析

本机器学习小白最近在学习pytorch,在学习.backward()函数的过程中一直不能理解参数gradient的作用,感觉相关资料中对它的解释过于简单,几乎忽略了相关数学过程。这里分享一下我的理解,希望能对有需要的同学有些帮助。

.backward()函数是pytorch用来实现反向传播计算的关键,不了解反向传播或者.backward()函数的可以看看下面的文章:

  • pytorch中backward()函数详解
  • pytorch的计算图
  • PyTorch的自动求导机制详细解析,PyTorch的核心魔法
  • 反向传播——通俗易懂
  • 深度学习——以图读懂反向传播

问题描述

官方文档对.backward()的说明是这样的(看英文头痛的同学可以忽略解释内容):
在这里插入图片描述
gradient的解释是这样的:
在这里插入图片描述
可以粗浅的认为backward()就是在根据计算图计算tensor的“梯度”,但是这个“梯度”只有在tensor是标量(scalar)时才是真正意义上的梯度。

gradient在实际使用的过程中,是一个与使用.backward()的tensor维度一致的tensor。官方文档中只提到如果使用.backward()的tensor是一个标量,就可以省略gradient参数,而如果这个tensor是向量,则必须指明gradient参数,但这个参数对计算的作用是什么,文档中说的很模糊(好像就没提到)。

看了一些文章之后发现,“如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。”(来源:pytorch中backward()函数详解)
这里提到的grad_tensors参数就是现在的gradient参数。

所以本质上,gradient参数在向量与向量的求导中起作用,而backward()在这种情况下求得的各个元素的梯度实际上并不是Jacobian,而是Jacobian与gradient的乘积。

以下结合一些例子说明backward()函数的计算结果。

实例分析

来源:PyTorch的自动求导机制详细解析,PyTorch的核心魔法

import torchx = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0, 1.0, 7.0], requires_grad = True)z = x * yz.backward(torch.FloatTensor([1.0, 1.0, 1.0]))

运行完之后查看z分别关于x和y的梯度可以发现:

>>>x.grad.data
tensor([5., 1., 7.])
>>>y.grad.data
tensor([0., 2., 8.])

实际上上述代码的计算结果可以这么理解:
x = ( x 1 x 2 x 3 ) = ( 0.0 2.0 8.0 ) y = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) x = \begin{pmatrix} x_1 & x_2 & x_3\end{pmatrix} = \begin{pmatrix} 0.0 & 2.0 & 8.0\end{pmatrix}\\ y = \begin{pmatrix} y_1 & y_2 & y_3\end{pmatrix} = \begin{pmatrix} 5.0 & 1.0 & 7.0\end{pmatrix} x=(x1x2x3)=(0.02.08.0)y=(y1y2y3)=(5.01.07.0) z z z向量则是 x x x y y y每项相乘得到的向量(不是点乘也不是叉乘)
z = ( x 1 y 1 x 2 y 2 x 3 y 3 ) z = \begin{pmatrix} x_1y_1 & x_2y_2 & x_3y_3\end{pmatrix} z=(x1y1x2y2x3y3)那么 z z z关于 x x x的Jacobian就是
J = ( ∂ z ∂ x 1 ∂ z ∂ x 2 ∂ z ∂ x 3 ) = ( ∂ z 1 ∂ x 1 ∂ z 1 ∂ x 2 ∂ z 1 ∂ x 3 ∂ z 2 ∂ x 1 ∂ z 2 ∂ x 2 ∂ z 2 ∂ x 3 ∂ z 3 ∂ x 1 ∂ z 3 ∂ x 2 ∂ z 3 ∂ x 3 ) = ( y 1 0 0 0 y 2 0 0 0 y 3 ) J=\begin{pmatrix} \frac{\partial z}{\partial x_1} & \frac{\partial z}{\partial x_2} & \frac{\partial z}{\partial x_3} \end{pmatrix} = \begin{pmatrix} \frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2} & \frac{\partial z_1}{\partial x_3} \\ \frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2} & \frac{\partial z_2}{\partial x_3} \\ \frac{\partial z_3}{\partial x_1} & \frac{\partial z_3}{\partial x_2} & \frac{\partial z_3}{\partial x_3} \end{pmatrix} =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} J=(x1zx2zx3z)=x1z1x1z2x1z3x2z1x2z2x2z3x3z1x3z2x3z3=y1000y2000y3而我们引入的gradient参数是一个向量(这里用 v v v表示)
v = ( 1.0 1.0 1.0 ) T v = \begin{pmatrix} 1.0 & 1.0 &1.0 \end{pmatrix}^{T} v=(1.01.01.0)T然后将 J J J v v v相乘,我们就得到了我们看到的x.grad.data的结果:
J v = ( y 1 0 0 0 y 2 0 0 0 y 3 ) ( 1.0 1.0 1.0 ) = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) Jv =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} \begin{pmatrix} 1.0 \\ 1.0 \\ 1.0 \end{pmatrix} = \begin{pmatrix} y_1 \\ y_2 \\ y_3 \end{pmatrix} = \begin{pmatrix} 5.0 \\ 1.0 \\ 7.0 \end{pmatrix} Jv=y1000y2000y31.01.01.0=y1y2y3=5.01.07.0y.grad.data的结果这里就不演示了,因为这个例子里 x x x y y y具有对称性,只需要把结果的 y y y换成 x x x就得到了y.grad.data的结果,所以我们可以看到x.grad.data和y.grad.data在数值上就是 y y y x x x的值。

这篇关于pytorch中backward函数的参数gradient作用的数学过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/294116

相关文章

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos

Java进程异常故障定位及排查过程

《Java进程异常故障定位及排查过程》:本文主要介绍Java进程异常故障定位及排查过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、故障发现与初步判断1. 监控系统告警2. 日志初步分析二、核心排查工具与步骤1. 进程状态检查2. CPU 飙升问题3. 内存

Java内存分配与JVM参数详解(推荐)

《Java内存分配与JVM参数详解(推荐)》本文详解JVM内存结构与参数调整,涵盖堆分代、元空间、GC选择及优化策略,帮助开发者提升性能、避免内存泄漏,本文给大家介绍Java内存分配与JVM参数详解,... 目录引言JVM内存结构JVM参数概述堆内存分配年轻代与老年代调整堆内存大小调整年轻代与老年代比例元空

SpringBoot整合liteflow的详细过程

《SpringBoot整合liteflow的详细过程》:本文主要介绍SpringBoot整合liteflow的详细过程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋...  liteflow 是什么? 能做什么?总之一句话:能帮你规范写代码逻辑 ,编排并解耦业务逻辑,代码

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

python常用的正则表达式及作用

《python常用的正则表达式及作用》正则表达式是处理字符串的强大工具,Python通过re模块提供正则表达式支持,本文给大家介绍python常用的正则表达式及作用详解,感兴趣的朋友跟随小编一起看看吧... 目录python常用正则表达式及作用基本匹配模式常用正则表达式示例常用量词边界匹配分组和捕获常用re

MySQL中的InnoDB单表访问过程

《MySQL中的InnoDB单表访问过程》:本文主要介绍MySQL中的InnoDB单表访问过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、环境3、访问类型【1】const【2】ref【3】ref_or_null【4】range【5】index【6】