PyTorch 的自动求导与计算图

2024-09-01 04:52
文章标签 计算 自动 pytorch 求导

本文主要是介绍PyTorch 的自动求导与计算图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在深度学习中,模型的训练过程本质上是通过梯度下降算法不断优化损失函数。为了高效地计算梯度,PyTorch 提供了强大的自动求导机制,这一机制依赖于“计算图”(Computational Graph)的概念。

1. 什么是计算图?

计算图是一种有向无环图(DAG),其中每个节点表示操作或变量,边表示数据的流动。简单来说,计算图是一个将复杂计算分解为一系列基本操作的图表。每个节点(通常称为“张量”)是一个数据单元,而边表示这些数据单元之间的计算关系。

例如,假设你有一个简单的函数 y = 2x + 1,这个函数可以表示为一个非常简单的计算图:

    x  ----->  2x  ----->  2x + 1

在这个图中,x 是一个输入张量,2x 是第一个操作节点,2x + 1 是第二个操作节点。PyTorch 会自动构建这个计算图,随着你对张量进行操作,图会动态扩展。

2. PyTorch 中的计算图

在 PyTorch 中,计算图是动态构建的。这意味着每次运行前向传播时,PyTorch 都会根据实际的操作构建计算图。这与其他静态图框架(如 TensorFlow 的早期版本)不同,后者需要先定义完整的图,然后再运行计算。

动态计算图的优点在于它灵活且易于调试。你可以在代码中使用 Python 的控制流(如条件语句、循环等),计算图会根据运行时的实际路径生成。

来看一个实际的例子:

import torch# 创建一个张量,并指定需要计算梯度
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)# 对张量进行操作
y = 2 * x + 1

在这段代码中,我们创建了一个名为 x 的张量,并通过 requires_grad=True 指定它是需要计算梯度的变量。这一步非常重要,因为只有 requires_grad 设置为 True 的张量,PyTorch 才会在计算图中跟踪它们的操作。

接下来,y = 2 * x + 1 执行了两个操作:先将 x 乘以 2,再加 1。每个操作都会在计算图中创建一个节点,表示计算的过程。这个计算图的结构可以描述为:

    x  ----->  2x  ----->  2x + 1

每个操作都被记录在计算图中,为反向传播过程做好准备。

3. 反向传播与梯度计算

当我们执行完前向计算后,接下来要做的就是通过反向传播计算梯度。梯度是指损失函数相对于输入变量的导数,用于指示在给定点处损失函数如何变化。

假设我们想计算 yx 的梯度。在 PyTorch 中,我们通过调用 backward() 方法来实现:

# 对 y 求和,然后执行反向传播
y.sum().backward()

y.sum() 是一个标量函数,将 y 的所有元素相加。这一步非常重要,因为在反向传播中,只有标量的梯度才能正确地传递。如果 y 不是标量,PyTorch 会对其进行求和,以确保反向传播的正确性。

执行 backward() 后,PyTorch 会自动计算 yx 的梯度,并将结果存储在 x.grad 中:

print(x.grad)  # 输出 [2.0, 2.0, 2.0, 2.0]

在这个例子中,dy/dx = 2,所以 x.grad 中的每个元素的值都是 2

4. 自动求导背后的数学原理

要理解自动求导,首先需要理解基本的微积分概念。导数反映了函数的变化率,是梯度下降算法的核心。

4.1 导数的概念

导数表示一个函数在某个点的瞬时变化率。如果你有一个简单的线性函数 y = 2x + 1,其导数是 2。这意味着,无论 x 的值是多少,y 的变化率都是常数 2

4.2 链式法则

链式法则是反向传播算法的基础。它告诉我们如何计算复合函数的导数。假设我们有两个函数 u = g(x)y = f(u),那么 yx 的导数可以通过链式法则计算:

dy/dx = (dy/du) * (du/dx)

在计算图中,链式法则对应于从输出节点到输入节点的梯度传递。每一步都遵循链式法则,将梯度从一层传递到下一层,最终计算出输入变量的梯度。

5. 复杂操作与控制流中的自动求导

PyTorch 的动态计算图不仅支持简单的操作,还可以处理更加复杂的操作和控制流。

5.1 非线性操作

非线性操作,如平方、指数运算等,使得计算图更加复杂。考虑下面的例子:

z = y ** 2  # z = (2x + 1) ^ 2

在这个例子中,计算图变为:

    x  ----->  2x  ----->  2x + 1  ----->  (2x + 1) ^ 2

此时,如果你对 z 进行反向传播,PyTorch 会首先计算 dz/dy,然后利用链式法则乘以 dy/dx,最终得到 dz/dx。通过调用 z.sum().backward(),你可以得到 zx 的梯度。

z.sum().backward()
print(x.grad)  # 输出 [12.0, 16.0, 20.0, 24.0]

在这里,x.grad 的值为 4x + 2,这就是 z = (2x + 1)^2x 的导数。

5.2 控制流中的求导

PyTorch 的自动求导机制同样可以处理控制流,比如条件语句和循环。对于动态计算图,控制流可以使得每次前向计算的图结构不同,但 PyTorch 依然能够正确计算梯度。

def my_func(a):if a.item() > 1:return a ** 2else:return a * 3x = torch.tensor(2.0, requires_grad=True)
y = my_func(x)
y.backward()
print(x.grad)  # 输出 4.0

在这个例子中,my_func 根据 a 的值执行不同的操作。如果 a > 1,则返回 a 的平方;否则,返回 a 的三倍。由于 x 的值为 2.0,所以计算的结果是 y = 4.0,而 yx 的导数为 4.0

6. 多变量函数的自动求导

在实际应用中,许多函数是多变量的。这时,PyTorch 同样可以计算每个变量的梯度。

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y = x1 ** 2 + x2 ** 3
y.backward()

在这个例子中,yx1x2 的函数。调用 backward() 后,x1.gradx2.grad 将分别存储 yx1x2 的导数。

print(x1.grad)  # 输出 2.0
print(x2.grad)  # 输出 12.0

x1.grad 的值为 2 * x1 = 2.0,而 x2.grad 的值为 3 * x2^2 = 12.0

7. detach() 的用途与计算图的修改

在某些情况下,你可能不希望某个张量参与计算图的反向传播。detach() 函数可以从计算图中分离出一个张量,使得它在反向传播时不影响梯度的计算。

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
z = y.detach() * 2  # z 与 y 无关,反向传播时不计算 z 的梯度
z.backward()
print(x.grad)  # 输出 None

在这里,由于 z 是从 y 中分离出来的,反向传播时 x.grad 不会受到 z 的影响。

此外,with torch.no_grad() 也可以用于临时停止计算图的构建,通常用于模型推理阶段。

8. 实际应用:深度学习中的梯度更新

自动求导在深度学习中的一个典型应用是梯度更新。在训练过程中,模型的参数会通过反向传播计算梯度,并使用优化器(如 SGD、Adam 等)更新这些参数。PyTorch 的 torch.optim 模块提供了多种优化器,可以自动利用计算出的梯度进行参数更新。

import torch.optim as optim# 创建一个简单的线性模型
model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入数据和目标
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y_true = torch.tensor([[2.0], [4.0], [6.0]])# 前向传播
y_pred = model(x)# 计算损失
loss = torch.nn.functional.mse_loss(y_pred, y_true)# 反向传播
loss.backward()# 更新参数
optimizer.step()

在这段代码中,我们创建了一个简单的线性模型,并使用 MSE 作为损失函数。通过反向传播计算梯度后,优化器会自动更新模型的参数,使损失逐渐减小。

9. 总结

PyTorch 的自动求导机制是深度学习中非常重要且强大的工具。它基于计算图自动计算梯度,极大地简化了模型训练中的梯度计算过程。无论是简单的线性函数还是复杂的神经网络,PyTorch 都能通过动态计算图和自动求导机制高效地进行梯度计算和参数优化。在实际应用中,掌握这些基础知识可以帮助我们更好地理解和优化深度学习模型。

这篇关于PyTorch 的自动求导与计算图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1125998

相关文章

poj 1113 凸包+简单几何计算

题意: 给N个平面上的点,现在要在离点外L米处建城墙,使得城墙把所有点都包含进去且城墙的长度最短。 解析: 韬哥出的某次训练赛上A出的第一道计算几何,算是大水题吧。 用convexhull算法把凸包求出来,然后加加减减就A了。 计算见下图: 好久没玩画图了啊好开心。 代码: #include <iostream>#include <cstdio>#inclu

uva 1342 欧拉定理(计算几何模板)

题意: 给几个点,把这几个点用直线连起来,求这些直线把平面分成了几个。 解析: 欧拉定理: 顶点数 + 面数 - 边数= 2。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#inc

uva 11178 计算集合模板题

题意: 求三角形行三个角三等分点射线交出的内三角形坐标。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#include <stack>#include <vector>#include <

XTU 1237 计算几何

题面: Magic Triangle Problem Description: Huangriq is a respectful acmer in ACM team of XTU because he brought the best place in regional contest in history of XTU. Huangriq works in a big compa

基于51单片机的自动转向修复系统的设计与实现

文章目录 前言资料获取设计介绍功能介绍设计清单具体实现截图参考文献设计获取 前言 💗博主介绍:✌全网粉丝10W+,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们电子相关专业的大学生,希望您们都共创辉煌!✌💗 👇🏻 精彩专栏 推荐订阅👇🏻 单片机

Python3 BeautifulSoup爬虫 POJ自动提交

POJ 提交代码采用Base64加密方式 import http.cookiejarimport loggingimport urllib.parseimport urllib.requestimport base64from bs4 import BeautifulSoupfrom submitcode import SubmitCodeclass SubmitPoj():de

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

计算数组的斜率,偏移,R2

模拟Excel中的R2的计算。         public bool fnCheckRear_R2(List<double[]> lRear, int iMinRear, int iMaxRear, ref double dR2)         {             bool bResult = true;             int n = 0;             dou

GPU 计算 CMPS224 2021 学习笔记 02

并行类型 (1)任务并行 (2)数据并行 CPU & GPU CPU和GPU拥有相互独立的内存空间,需要在两者之间相互传输数据。 (1)分配GPU内存 (2)将CPU上的数据复制到GPU上 (3)在GPU上对数据进行计算操作 (4)将计算结果从GPU复制到CPU上 (5)释放GPU内存 CUDA内存管理API (1)分配内存 cudaErro

Java - BigDecimal 计算分位(百分位)

日常开发中,如果使用数据库来直接查询一组数据的分位数,就比较简单,直接使用对应的函数就可以了,例如:         PERCENT_RANK() OVER(PARTITION BY 分组列名 ORDER BY 目标列名) AS 目标列名_分位数         如果是需要在代码逻辑部分进行分位数的计算,就需要我们自己写一个工具类来支持计算了 import static ja