分布式训练同步梯度出现形状不一致的解决方案

2024-09-06 19:12

本文主要是介绍分布式训练同步梯度出现形状不一致的解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、问题描述

          为了加快大模型的训练速度,采用了分布式训练策略,基于MultiWorkerServerStrategy模式,集群之间采用Ring—Reduce的通信机制,不同节点在同步梯度会借助collective_ops.all_gather方法将梯度进行汇聚收集,汇聚过程出现了:

allreduce_1/CollectiveGather_1 Inconsitent output shapes,got[20],but expected is [22]

allreduce_1/CollectiveGather  Inconsitent output shapes,got[16,8],but expected is [20,8]

从而终止了训练继续进行。

2、原因分析

         直观看是因为不连续的输出形状,即要求的输出形状对于第一个是[22],却输出了[20],造成了不一致,查阅相关资料发现在tensorflow1.15早期的版本中,底层的源码文件tensorflow/core/kernels/collective_ops.cc

当col_params_.instance.shape.num_elements() == 0时表明是首次批来的时候,记住了output_shape,当第二批次或后面的批次再来的时候,强行判断和首次记住的形状保持一致,如果不一致就报错打印出了上面的“输出形状不连续的问题”,即errors::Internal里的内容。这也就是之所以报梯度形状不一致的根本原因。

3、问题解决

          分析清楚了原因,制定对应的解决办法。当然可以将该段代码的逻辑去掉,当后面批次再来的时候,不做判断,而是让col_params_.instance.shape=output_shape始终跟最新你的输出保持一致。如下所示:去掉了老版本里的if else的判断,直接让col_params_.instance.shape=out_shape,兼容输出可变化的动态形状

该种解决方案优点是从根本上解决,上层应用无感知,然而缺点是改完后要重新编译cc代码生成so文件,或者升级到最新的版本,对于不开放的网络环境,升级tensorflow或重新编译成本巨高。为了依旧使用老版本,尽可能不动底层,采用修改上层的方法,虽然繁琐一些,但是修改成本会低恨多。

        对该问题进行更深入的分析,到底为什么梯度输出的形状会发生改变,即out_shape和首次批的输出形状可能会不一致呢?仔细梳理了每个批次的梯度产生过程。以一个id类特征product_id为例,假定训练的batch_size=1024,product总个数30,embedding后总参数大小[30,8], 第一个批次输入的批次数据是[1024,8],即每次输入1024个样本,而梯度回传有时候是[20,8],有时候是[16,8],之所以和输入没有对齐,经分析发现在反向传播,通过collective_all_gather收集各个集群上的梯度时,是以特征变量为单位,不是以样本量来衡量的,即会收集每个训练的特征变量,在各个节点上的梯度,收回来做累加或其他聚合操作,在有些1024的批data中,product_id包含20个(不同样本的product_id会有重复),有些1024的批data中,product_id包含16个,这样反回来的梯度是这20个或16个product_id的emb的梯度,所以看到的梯度的形状是[20,8]和[16,8],这也体现了训练的过程更新的可训练特征变量,以不同特征变量的个数来组织梯度也顺利成章。

解决办法:当一个批次的数据大小是batch_size的时候,根据以上分析,某个特征变量的不同值的个数上限是batch_size个,因此把梯度的形状pad成[batch_size,dim],这样就就保证了每次进入collective_op.all_gather的形状保持了一致,另一个问题就是这个batch_size如何传入cross_device_utils.py,刚上来考虑通过获取tensorgrah中输入变量的第一维的值来作为batch_size,这样会有个问题就是,当最后batch的大小不够一个batch_size的时候,补的形状就和前面的又一样,还是会失败,训练最后一步挂掉;因此考虑传入固定的静态手工配置的batch_size,通过参数传递的方式,内部经过的链路很长,会进入不同的模块,才会传导到cross_device_utils.py,这种方式改动太大,自然而然想到共享内存,python的共享内存可能需要第三方的工具包,成本也高,进而考虑共享文件,启动的时候将静态固定的batch_size写入一个固定的目录文件,在cross_device_utils.py里用到的时候读取文件,这样改动的成本还是有些繁琐,最后考虑python夸文件的变量共享,在cross_device_utils.py,定义一个全局变量global_batch_size给定默认值256,在训练启动的python文件main方法里通过引用修改该变量,即:

from tensorflow.python.distribute import cross_device_utils

cross_device_utils.global_batch_size=Configs[‘batch_size’]

具体修改如下:

修改前:

修改后:

4、总结

        本文对Ring-AllReduce通信框架下分布式训练梯度收集形状不一致的问题进行了分析,并阐述了从最底层和偏上层的不同解决思路。对使用稍早版本的tf搭建分布式训练平台有一定的借鉴作用。

这篇关于分布式训练同步梯度出现形状不一致的解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142878

相关文章

深入理解Redis大key的危害及解决方案

《深入理解Redis大key的危害及解决方案》本文主要介绍了深入理解Redis大key的危害及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、背景二、什么是大key三、大key评价标准四、大key 产生的原因与场景五、大key影响与危

Xshell远程连接失败以及解决方案

《Xshell远程连接失败以及解决方案》本文介绍了在Windows11家庭版和CentOS系统中解决Xshell无法连接远程服务器问题的步骤,在Windows11家庭版中,需要通过设置添加SSH功能并... 目录一.问题描述二.原因分析及解决办法2.1添加ssh功能2.2 在Windows中开启ssh服务2

Redis连接失败:客户端IP不在白名单中的问题分析与解决方案

《Redis连接失败:客户端IP不在白名单中的问题分析与解决方案》在现代分布式系统中,Redis作为一种高性能的内存数据库,被广泛应用于缓存、消息队列、会话存储等场景,然而,在实际使用过程中,我们可能... 目录一、问题背景二、错误分析1. 错误信息解读2. 根本原因三、解决方案1. 将客户端IP添加到Re

java如何分布式锁实现和选型

《java如何分布式锁实现和选型》文章介绍了分布式锁的重要性以及在分布式系统中常见的问题和需求,它详细阐述了如何使用分布式锁来确保数据的一致性和系统的高可用性,文章还提供了基于数据库、Redis和Zo... 目录引言:分布式锁的重要性与分布式系统中的常见问题和需求分布式锁的重要性分布式系统中常见的问题和需求

详谈redis跟数据库的数据同步问题

《详谈redis跟数据库的数据同步问题》文章讨论了在Redis和数据库数据一致性问题上的解决方案,主要比较了先更新Redis缓存再更新数据库和先更新数据库再更新Redis缓存两种方案,文章指出,删除R... 目录一、Redis 数据库数据一致性的解决方案1.1、更新Redis缓存、删除Redis缓存的区别二

python 字典d[k]中key不存在的解决方案

《python字典d[k]中key不存在的解决方案》本文主要介绍了在Python中处理字典键不存在时获取默认值的两种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录defaultdict:处理找不到的键的一个选择特殊方法__missing__有时候为了方便起见,

Golang使用etcd构建分布式锁的示例分享

《Golang使用etcd构建分布式锁的示例分享》在本教程中,我们将学习如何使用Go和etcd构建分布式锁系统,分布式锁系统对于管理对分布式系统中共享资源的并发访问至关重要,它有助于维护一致性,防止竞... 目录引言环境准备新建Go项目实现加锁和解锁功能测试分布式锁重构实现失败重试总结引言我们将使用Go作

Redis分布式锁使用及说明

《Redis分布式锁使用及说明》本文总结了Redis和Zookeeper在高可用性和高一致性场景下的应用,并详细介绍了Redis的分布式锁实现方式,包括使用Lua脚本和续期机制,最后,提到了RedLo... 目录Redis分布式锁加锁方式怎么会解错锁?举个小案例吧解锁方式续期总结Redis分布式锁如果追求

Linux限制ip访问的解决方案

《Linux限制ip访问的解决方案》为了修复安全扫描中发现的漏洞,我们需要对某些服务设置访问限制,具体来说,就是要确保只有指定的内部IP地址能够访问这些服务,所以本文给大家介绍了Linux限制ip访问... 目录背景:解决方案:使用Firewalld防火墙规则验证方法深度了解防火墙逻辑应用场景与扩展背景:

SpringBoot嵌套事务详解及失效解决方案

《SpringBoot嵌套事务详解及失效解决方案》在复杂的业务场景中,嵌套事务可以帮助我们更加精细地控制数据的一致性,然而,在SpringBoot中,如果嵌套事务的配置不当,可能会导致事务不生效的问题... 目录什么是嵌套事务?嵌套事务失效的原因核心问题:嵌套事务的解决方案方案一:将嵌套事务方法提取到独立类