PyTorch 的自动求导与计算图

2024-09-01 04:52
文章标签 计算 自动 pytorch 求导

本文主要是介绍PyTorch 的自动求导与计算图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在深度学习中,模型的训练过程本质上是通过梯度下降算法不断优化损失函数。为了高效地计算梯度,PyTorch 提供了强大的自动求导机制,这一机制依赖于“计算图”(Computational Graph)的概念。

1. 什么是计算图?

计算图是一种有向无环图(DAG),其中每个节点表示操作或变量,边表示数据的流动。简单来说,计算图是一个将复杂计算分解为一系列基本操作的图表。每个节点(通常称为“张量”)是一个数据单元,而边表示这些数据单元之间的计算关系。

例如,假设你有一个简单的函数 y = 2x + 1,这个函数可以表示为一个非常简单的计算图:

    x  ----->  2x  ----->  2x + 1

在这个图中,x 是一个输入张量,2x 是第一个操作节点,2x + 1 是第二个操作节点。PyTorch 会自动构建这个计算图,随着你对张量进行操作,图会动态扩展。

2. PyTorch 中的计算图

在 PyTorch 中,计算图是动态构建的。这意味着每次运行前向传播时,PyTorch 都会根据实际的操作构建计算图。这与其他静态图框架(如 TensorFlow 的早期版本)不同,后者需要先定义完整的图,然后再运行计算。

动态计算图的优点在于它灵活且易于调试。你可以在代码中使用 Python 的控制流(如条件语句、循环等),计算图会根据运行时的实际路径生成。

来看一个实际的例子:

import torch# 创建一个张量,并指定需要计算梯度
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)# 对张量进行操作
y = 2 * x + 1

在这段代码中,我们创建了一个名为 x 的张量,并通过 requires_grad=True 指定它是需要计算梯度的变量。这一步非常重要,因为只有 requires_grad 设置为 True 的张量,PyTorch 才会在计算图中跟踪它们的操作。

接下来,y = 2 * x + 1 执行了两个操作:先将 x 乘以 2,再加 1。每个操作都会在计算图中创建一个节点,表示计算的过程。这个计算图的结构可以描述为:

    x  ----->  2x  ----->  2x + 1

每个操作都被记录在计算图中,为反向传播过程做好准备。

3. 反向传播与梯度计算

当我们执行完前向计算后,接下来要做的就是通过反向传播计算梯度。梯度是指损失函数相对于输入变量的导数,用于指示在给定点处损失函数如何变化。

假设我们想计算 yx 的梯度。在 PyTorch 中,我们通过调用 backward() 方法来实现:

# 对 y 求和,然后执行反向传播
y.sum().backward()

y.sum() 是一个标量函数,将 y 的所有元素相加。这一步非常重要,因为在反向传播中,只有标量的梯度才能正确地传递。如果 y 不是标量,PyTorch 会对其进行求和,以确保反向传播的正确性。

执行 backward() 后,PyTorch 会自动计算 yx 的梯度,并将结果存储在 x.grad 中:

print(x.grad)  # 输出 [2.0, 2.0, 2.0, 2.0]

在这个例子中,dy/dx = 2,所以 x.grad 中的每个元素的值都是 2

4. 自动求导背后的数学原理

要理解自动求导,首先需要理解基本的微积分概念。导数反映了函数的变化率,是梯度下降算法的核心。

4.1 导数的概念

导数表示一个函数在某个点的瞬时变化率。如果你有一个简单的线性函数 y = 2x + 1,其导数是 2。这意味着,无论 x 的值是多少,y 的变化率都是常数 2

4.2 链式法则

链式法则是反向传播算法的基础。它告诉我们如何计算复合函数的导数。假设我们有两个函数 u = g(x)y = f(u),那么 yx 的导数可以通过链式法则计算:

dy/dx = (dy/du) * (du/dx)

在计算图中,链式法则对应于从输出节点到输入节点的梯度传递。每一步都遵循链式法则,将梯度从一层传递到下一层,最终计算出输入变量的梯度。

5. 复杂操作与控制流中的自动求导

PyTorch 的动态计算图不仅支持简单的操作,还可以处理更加复杂的操作和控制流。

5.1 非线性操作

非线性操作,如平方、指数运算等,使得计算图更加复杂。考虑下面的例子:

z = y ** 2  # z = (2x + 1) ^ 2

在这个例子中,计算图变为:

    x  ----->  2x  ----->  2x + 1  ----->  (2x + 1) ^ 2

此时,如果你对 z 进行反向传播,PyTorch 会首先计算 dz/dy,然后利用链式法则乘以 dy/dx,最终得到 dz/dx。通过调用 z.sum().backward(),你可以得到 zx 的梯度。

z.sum().backward()
print(x.grad)  # 输出 [12.0, 16.0, 20.0, 24.0]

在这里,x.grad 的值为 4x + 2,这就是 z = (2x + 1)^2x 的导数。

5.2 控制流中的求导

PyTorch 的自动求导机制同样可以处理控制流,比如条件语句和循环。对于动态计算图,控制流可以使得每次前向计算的图结构不同,但 PyTorch 依然能够正确计算梯度。

def my_func(a):if a.item() > 1:return a ** 2else:return a * 3x = torch.tensor(2.0, requires_grad=True)
y = my_func(x)
y.backward()
print(x.grad)  # 输出 4.0

在这个例子中,my_func 根据 a 的值执行不同的操作。如果 a > 1,则返回 a 的平方;否则,返回 a 的三倍。由于 x 的值为 2.0,所以计算的结果是 y = 4.0,而 yx 的导数为 4.0

6. 多变量函数的自动求导

在实际应用中,许多函数是多变量的。这时,PyTorch 同样可以计算每个变量的梯度。

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)
y = x1 ** 2 + x2 ** 3
y.backward()

在这个例子中,yx1x2 的函数。调用 backward() 后,x1.gradx2.grad 将分别存储 yx1x2 的导数。

print(x1.grad)  # 输出 2.0
print(x2.grad)  # 输出 12.0

x1.grad 的值为 2 * x1 = 2.0,而 x2.grad 的值为 3 * x2^2 = 12.0

7. detach() 的用途与计算图的修改

在某些情况下,你可能不希望某个张量参与计算图的反向传播。detach() 函数可以从计算图中分离出一个张量,使得它在反向传播时不影响梯度的计算。

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
z = y.detach() * 2  # z 与 y 无关,反向传播时不计算 z 的梯度
z.backward()
print(x.grad)  # 输出 None

在这里,由于 z 是从 y 中分离出来的,反向传播时 x.grad 不会受到 z 的影响。

此外,with torch.no_grad() 也可以用于临时停止计算图的构建,通常用于模型推理阶段。

8. 实际应用:深度学习中的梯度更新

自动求导在深度学习中的一个典型应用是梯度更新。在训练过程中,模型的参数会通过反向传播计算梯度,并使用优化器(如 SGD、Adam 等)更新这些参数。PyTorch 的 torch.optim 模块提供了多种优化器,可以自动利用计算出的梯度进行参数更新。

import torch.optim as optim# 创建一个简单的线性模型
model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入数据和目标
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y_true = torch.tensor([[2.0], [4.0], [6.0]])# 前向传播
y_pred = model(x)# 计算损失
loss = torch.nn.functional.mse_loss(y_pred, y_true)# 反向传播
loss.backward()# 更新参数
optimizer.step()

在这段代码中,我们创建了一个简单的线性模型,并使用 MSE 作为损失函数。通过反向传播计算梯度后,优化器会自动更新模型的参数,使损失逐渐减小。

9. 总结

PyTorch 的自动求导机制是深度学习中非常重要且强大的工具。它基于计算图自动计算梯度,极大地简化了模型训练中的梯度计算过程。无论是简单的线性函数还是复杂的神经网络,PyTorch 都能通过动态计算图和自动求导机制高效地进行梯度计算和参数优化。在实际应用中,掌握这些基础知识可以帮助我们更好地理解和优化深度学习模型。

这篇关于PyTorch 的自动求导与计算图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1125998

相关文章

python实现自动登录12306自动抢票功能

《python实现自动登录12306自动抢票功能》随着互联网技术的发展,越来越多的人选择通过网络平台购票,特别是在中国,12306作为官方火车票预订平台,承担了巨大的访问量,对于热门线路或者节假日出行... 目录一、遇到的问题?二、改进三、进阶–展望总结一、遇到的问题?1.url-正确的表头:就是首先ur

Spring使用@Retryable实现自动重试机制

《Spring使用@Retryable实现自动重试机制》在微服务架构中,服务之间的调用可能会因为一些暂时性的错误而失败,例如网络波动、数据库连接超时或第三方服务不可用等,在本文中,我们将介绍如何在Sp... 目录引言1. 什么是 @Retryable?2. 如何在 Spring 中使用 @Retryable

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

使用 Python 和 LabelMe 实现图片验证码的自动标注功能

《使用Python和LabelMe实现图片验证码的自动标注功能》文章介绍了如何使用Python和LabelMe自动标注图片验证码,主要步骤包括图像预处理、OCR识别和生成标注文件,通过结合Pa... 目录使用 python 和 LabelMe 实现图片验证码的自动标注环境准备必备工具安装依赖实现自动标注核心

QT实现TCP客户端自动连接

《QT实现TCP客户端自动连接》这篇文章主要为大家详细介绍了QT中一个TCP客户端自动连接的测试模型,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录版本 1:没有取消按钮 测试效果测试代码版本 2:有取消按钮测试效果测试代码版本 1:没有取消按钮 测试效果缺陷:无法手动停

poj 1113 凸包+简单几何计算

题意: 给N个平面上的点,现在要在离点外L米处建城墙,使得城墙把所有点都包含进去且城墙的长度最短。 解析: 韬哥出的某次训练赛上A出的第一道计算几何,算是大水题吧。 用convexhull算法把凸包求出来,然后加加减减就A了。 计算见下图: 好久没玩画图了啊好开心。 代码: #include <iostream>#include <cstdio>#inclu

uva 1342 欧拉定理(计算几何模板)

题意: 给几个点,把这几个点用直线连起来,求这些直线把平面分成了几个。 解析: 欧拉定理: 顶点数 + 面数 - 边数= 2。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#inc

uva 11178 计算集合模板题

题意: 求三角形行三个角三等分点射线交出的内三角形坐标。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#include <stack>#include <vector>#include <

XTU 1237 计算几何

题面: Magic Triangle Problem Description: Huangriq is a respectful acmer in ACM team of XTU because he brought the best place in regional contest in history of XTU. Huangriq works in a big compa

基于51单片机的自动转向修复系统的设计与实现

文章目录 前言资料获取设计介绍功能介绍设计清单具体实现截图参考文献设计获取 前言 💗博主介绍:✌全网粉丝10W+,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们电子相关专业的大学生,希望您们都共创辉煌!✌💗 👇🏻 精彩专栏 推荐订阅👇🏻 单片机