PyTorch手动梯度清零

2023-11-01 12:50
文章标签 pytorch 手动 梯度 清零

本文主要是介绍PyTorch手动梯度清零,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简单介绍

Pytorch中,每个batch训练完后需要使用Variable.grad.zero_()进行梯度清零(其中Variable为变量名,性质为torch.tensor).在Pytorch中,之所以需要手动进行梯度清零,而不是选择自动清零,是因为这种方式可以让使用者自由选择梯度清零的时机,具有更高的灵活性.例如选择训练每N个batch后再进行梯度更新和清零,这相当于将原来的batch_size扩大为N×batch_size.因为原先是每个batch_size训练完后直接更新,而现在变为N个batch_size训练完才更新,相当于将N个batch_size合为了一组.这样可以让使用者使用较低的配置,跑较高的batch_size.

具体用法

第一种,直接梯度清零

直接梯度清零
上图中的用法为常规用法,每完成一个batch,更新一次梯度,完成一次训练.

第二种,batch叠加

batch叠加
上图中关键步骤有两步
第一步为:

loss = loss / accumulation_setps

其中accumulation_setps即是我在简单介绍中提到的N,用来控制梯度更新时机.由于损失函数均需要求平均,如果没有上图中的代码,相当于accumulation_setps个batch的损失值简单相加(类似于batch中m个数据的损失相加后并不求平均),这显然是不合适的.
第二步为:

if((i+1)%accumulation_setps) == 0optimizer.step()optimizer.zero_grad()

该步意义简单明了,即是每accumulation_setps个batch梯度清一次零,并完成一次训练.
两张图片均来源于一篇知乎问答,具体在哪找不到了.

这篇关于PyTorch手动梯度清零的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/323098

相关文章

✨机器学习笔记(二)—— 线性回归、代价函数、梯度下降

1️⃣线性回归(linear regression) f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b​(x)=wx+b 🎈A linear regression model predicting house prices: 如图是机器学习通过监督学习运用线性回归模型来预测房价的例子,当房屋大小为1250 f e e t 2 feet^

AI学习指南深度学习篇-带动量的随机梯度下降法的基本原理

AI学习指南深度学习篇——带动量的随机梯度下降法的基本原理 引言 在深度学习中,优化算法被广泛应用于训练神经网络模型。随机梯度下降法(SGD)是最常用的优化算法之一,但单独使用SGD在收敛速度和稳定性方面存在一些问题。为了应对这些挑战,动量法应运而生。本文将详细介绍动量法的原理,包括动量的概念、指数加权移动平均、参数更新等内容,最后通过实际示例展示动量如何帮助SGD在参数更新过程中平稳地前进。

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

AI学习指南深度学习篇-带动量的随机梯度下降法简介

AI学习指南深度学习篇 - 带动量的随机梯度下降法简介 引言 在深度学习的广阔领域中,优化算法扮演着至关重要的角色。它们不仅决定了模型训练的效率,还直接影响到模型的最终表现之一。随着神经网络模型的不断深化和复杂化,传统的优化算法在许多领域逐渐暴露出其不足之处。带动量的随机梯度下降法(Momentum SGD)应运而生,并被广泛应用于各类深度学习模型中。 在本篇文章中,我们将深入探讨带动量的随

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

什么是GPT-3的自回归架构?为什么GPT-3无需梯度更新和微调

文章目录 知识回顾GPT-3的自回归架构何为自回归架构为什么架构会影响任务表现自回归架构的局限性与双向模型的对比小结 为何无需梯度更新和微调为什么不需要怎么做到不需要 🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发,目前开始人工智能领域相关知识的学习 🦅个人主页:@逐梦苍穹 📕所属专栏:人工智能 🌻gitee地址:x

ASP.NET手动触发页面验证控件事件

开发环境:.NET Framework 3.5.1 sp1 参考文章: http://www.codeproject.com/KB/aspnet/JavascriptValidation.aspx http://msdn.microsoft.com/zh-cn/library/aa479045.aspx http://www.cnblogs.com/minsentinel/archive/

【VSCode v1.93.0】手动配置远程remote-ssh

开发环境 VS Code版本:1.93.0 (Windows) Ubuntu版本:20.04 使用VS Code 插件remote-ssh远程访问Ubuntu服务器中的代码,若Ubuntu无法联网,在连接的时候会报错: Could not establish connection to "xxxx": Failed to download VS Code Server(Failed to