Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches)

本文主要是介绍Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://blog.csdn.net/caroline_wendy/article/details/137640653

PyTorch

在 PyTorch Lightning 中,pl.Traineraccumulate_grad_batches 参数允许在执行反向传播和优化器步骤之前,累积多个批次的梯度。这样,可以增加有效的批次大小,而不会增加内存开销。例如,如果设置 accumulate_grad_batches=8,则会在执行优化器的 .step() 方法之前,累积 8 个批次的梯度。

accumulate_grad_batchesglobal_step 的关系:

  1. global_step 会在每次调用优化器的 .step() 方法后递增。
  2. 使用梯度累积,global_step 增长小于 批次(batch) 的数量
  3. 多个批次贡献到 1 个 global_step 的更新中。

例如,如果 accumulate_grad_batches=8,那么每 8 个批次,只会增加 1 次 global_step,如果多卡,则 global_step 表示单卡的次数。日志,如下:

[INFO] [CL] global_step: 0, iter_step: 8
[INFO] [CL] global_step: 1, iter_step: 16

其中 pl.Trainer 的源码:

    trainer = pl.Trainer(accelerator="gpu",# ...accumulate_grad_batches=args.accumulate_grad,strategy=strategy,  # 多机多卡配置num_nodes=args.num_nodes,  # 节点数devices=1,  # 每个节点 GPU 卡数)

输出日志:

log = {'epoch': self.trainer.current_epoch, 'step': self.trainer.global_step}
wandb.log(log)

这篇关于Training - PyTorch Lightning 分布式训练的 global_step 参数 (accumulate_grad_batches)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/896286

相关文章

Eureka高可用注册中心registered-replicas没有分布式注册中心

自己在学习过程中发现,如果Eureka挂掉了,其他的Client就跑不起来了,那既然是商业项目,还是要处理好这个问题,所以决定用《Spring Cloud微服务实战》(PDF版在全栈技术交流群中自行获取)中说的“高可用注册中心”。 一开始我yml的配置是这样的 server:port: 8761eureka:instance:hostname: 127.0.0.1client:fetch-r

ABAP怎么把传入的参数刷新到内表里面呢?

1.在执行相关的功能操作之前,优先执行这一段代码,把输入的数据更新入内表里面 DATA: lo_guid TYPE REF TO cl_gui_alv_grid.CALL FUNCTION 'GET_GLOBALS_FROM_SLVC_FULLSCR'IMPORTINGe_grid = lo_guid.CALL METHOD lo_guid->check_changed_data.CALL M

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

YOLO v3 训练速度慢的问题

一天一夜出了两个模型,仅仅迭代了200次   原因:编译之前没有将Makefile 文件里的GPU设置为1,编译的是CPU版本,必须训练慢   解决方案: make clean  vim Makefile make   再次训练 速度快了,5分钟迭代了500次

Java面试八股之JVM参数-XX:+UseCompressedOops的作用

JVM参数-XX:+UseCompressedOops的作用 JVM参数-XX:+UseCompressedOops的作用是启用对象指针压缩(Ordinary Object Pointers compression)。这一特性主要应用于64位的Java虚拟机中,目的是为了减少内存使用。在传统的64位系统中,对象引用(即指针)通常占用8字节(64位),而大部分应用程序实际上并不需要如此大的地址空间

将一维机械振动信号构造为训练集和测试集(Python)

从如下链接中下载轴承数据集。 https://www.sciencedirect.com/science/article/pii/S2352340918314124 import numpy as npimport scipy.io as sioimport matplotlib.pyplot as pltimport statistics as statsimport pandas

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

[分布式网络通讯框架]----Zookeeper客户端基本操作----ls、get、create、set、delete

Zookeeper数据结构 zk客户端常用命令 进入客户端 在bin目录下输入./zkCli.sh 查看根目录下数据ls / 注意:要查看哪一个节点,必须把路径写全 查看节点数据信息 get /第一行代码数据,没有的话表示没有数据 创建节点create /sl 20 /sl为节点的路径,20为节点的数据 注意,不能跨越创建,也就是说,创建sl2的时候,必须确保sl

关于命令行参数argv(《学习OpenCV》)

在《学习OpenCV》这本书中,很多示例代码都用到了命令行参数。作为新手,之前总是很困扰,不知道怎么用。偶然的机会终于略知一二了。 在Visual Studio中,我们可以自行设置命令行参数。 如在这个示例程序中,我们想把图像存入argv[1]。 方法如下: 依次点击,项目、属性、配置属性、调试、命令参数。出现下面的界面: 然后进行编辑,即输入图像路径。如:E:\Lena.jpg

[分布式网络通讯框架]----ZooKeeper下载以及Linux环境下安装与单机模式部署(附带每一步截图)

首先进入apache官网 点击中间的see all Projects->Project List菜单项进入页面 找到zookeeper,进入 在Zookeeper主页的顶部点击菜单Project->Releases,进入Zookeeper发布版本信息页面,如下图: 找到需要下载的版本 进行下载既可,这里我已经下载过3.4.10,所以以下使用3.4.10进行演示其他的步骤。