PyTorch-Lightning:trining_step的自动优化

2024-04-12 09:44

本文主要是介绍PyTorch-Lightning:trining_step的自动优化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • PyTorch-Lightning:trining_step的自动优化
      • 总结:
    • class _ AutomaticOptimization()
      • def run
      • def _make_closure
      • def _training_step
        • class ClosureResult():
          • def from_training_step_output
      • class Closure

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

在这里插入图片描述

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:closure = self._make_closure(kwargs, optimizer, batch_idx)if (# when the strategy handles accumulation, we want to always call the optimizer stepnot self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()):# For gradient accumulation# -------------------# calculate loss (train step + train step end)# -------------------# automatic_optimization=True: perform ddp sync only when performing optimizer_stepwith _block_parallel_sync_behavior(self.trainer.strategy, block=True):closure()# ------------------------------# BACKWARD PASS# ------------------------------# gradient update with accumulated gradientselse:self._optimizer_step(batch_idx, closure)result = closure.consume_result()if result.loss is None:return {}return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:step_fn = self._make_step_fn(kwargs)backward_fn = self._make_backward_fn(optimizer)zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]loss: Optional[Tensor] = field(init=False, default=None)extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数

class Closure(AbstractClosure[ClosureResult]):warning_cache = WarningCache()def __init__(self,step_fn: Callable[[], ClosureResult],backward_fn: Optional[Callable[[Tensor], None]] = None,zero_grad_fn: Optional[Callable[[], None]] = None,):super().__init__()self._step_fn = step_fnself._backward_fn = backward_fnself._zero_grad_fn = zero_grad_fn@override@torch.enable_grad()def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:step_output = self._step_fn()if step_output.closure_loss is None:self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")if self._zero_grad_fn is not None:self._zero_grad_fn()if self._backward_fn is not None and step_output.closure_loss is not None:self._backward_fn(step_output.closure_loss)return step_output@overridedef __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:self._result = self.closure(*args, **kwargs)return self._result.loss

这篇关于PyTorch-Lightning:trining_step的自动优化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/896785

相关文章

uniapp接入微信小程序原生代码配置方案(优化版)

uniapp项目需要把微信小程序原生语法的功能代码嵌套过来,无需把原生代码转换为uniapp,可以配置拷贝的方式集成过来 1、拷贝代码包到src目录 2、vue.config.js中配置原生代码包直接拷贝到编译目录中 3、pages.json中配置分包目录,原生入口组件的路径 4、manifest.json中配置分包,使用原生组件 5、需要把原生代码包里的页面修改成组件的方

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

WordPress网创自动采集并发布插件

网创教程:WordPress插件网创自动采集并发布 阅读更新:随机添加文章的阅读数量,购买数量,喜欢数量。 使用插件注意事项 如果遇到404错误,请先检查并调整网站的伪静态设置,这是最常见的问题。需要定制化服务,请随时联系我。 本次更新内容 我们进行了多项更新和优化,主要包括: 界面设置:用户现在可以更便捷地设置文章分类和发布金额。代码优化:改进了采集和发布代码,提高了插件的稳定

服务器雪崩的应对策略之----SQL优化

SQL语句的优化是数据库性能优化的重要方面,特别是在处理大规模数据或高频访问时。作为一个C++程序员,理解SQL优化不仅有助于编写高效的数据库操作代码,还能增强对系统性能瓶颈的整体把握。以下是详细的SQL语句优化技巧和策略: SQL优化 1. 选择合适的数据类型2. 使用索引3. 优化查询4. 范式化和反范式化5. 查询重写6. 使用缓存7. 优化数据库设计8. 分析和监控9. 调整配置1、

【青龙面板辅助】JD商品自动给好评获取京豆脚本

1.打开链接 开下面的链接进入待评价商品页面 https://club.jd.com/myJdcomments/myJdcomments.action?sort=0 2.登陆后执行脚本 登陆后,按F12键,选择console,复制粘贴以下代码,先运行脚本1,再运行脚本2 脚本1代码 可以自行修改评价内容。 var content = '材质很好,质量也不错,到货也很快物流满分,包装快递满

Java中如何优化数据库查询性能?

Java中如何优化数据库查询性能? 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将深入探讨在Java中如何优化数据库查询性能,这是提升应用程序响应速度和用户体验的关键技术。 优化数据库查询性能的重要性 在现代应用开发中,数据库查询是最常见的操作之一。随着数据量的增加和业务复杂度的提升,数据库查询的性能优化显得尤为重

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

AI炒股:自动画出A股股票的K线图并添加技术指标

在deepseek中输入提示词: 你是一个Python编程专家,要完成一个编写Python脚本的任务,具体步骤如下: 用AKShare库获取股票贵州茅台(股票代码:600519)在2024年3月7日到2024年6月5日期间的历史行情数据-前复权。 然后绘制K线图,并在K线图上添加布林线、MACD 注意: 每一步都要输出信息到屏幕上; 日期格式是YYYYMMDD; 设置中文字体,以解决

XMG 自动提示宏 #define keyPath(objc,keyPath) @(((void)objc.keyPath,#keyPath));

1. int a=((void)5,4)  C语言逗号表达式默认会取右边的内容 如果不写void的话 a会被报警告,写上void标明请忽略左边的内容 插曲刚才弄得,已经上线的苹果产品如果需要下架的话,点击 价格与销售范围,然后点击下架。这个产品就会在AppStore 中移除。如果想再让改产品重新在Apple store中显示,那么再次让他上线就可以了。但是会有一定的时间延迟 /

打包体积分析和优化

webpack分析工具:webpack-bundle-analyzer 1. 通过<script src="./vue.js"></script>方式引入vue、vuex、vue-router等包(CDN) // webpack.config.jsif(process.env.NODE_ENV==='production') {module.exports = {devtool: 'none