分布式训练同步梯度出现形状不一致的解决方案

2024-09-06 19:12

本文主要是介绍分布式训练同步梯度出现形状不一致的解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、问题描述

          为了加快大模型的训练速度,采用了分布式训练策略,基于MultiWorkerServerStrategy模式,集群之间采用Ring—Reduce的通信机制,不同节点在同步梯度会借助collective_ops.all_gather方法将梯度进行汇聚收集,汇聚过程出现了:

allreduce_1/CollectiveGather_1 Inconsitent output shapes,got[20],but expected is [22]

allreduce_1/CollectiveGather  Inconsitent output shapes,got[16,8],but expected is [20,8]

从而终止了训练继续进行。

2、原因分析

         直观看是因为不连续的输出形状,即要求的输出形状对于第一个是[22],却输出了[20],造成了不一致,查阅相关资料发现在tensorflow1.15早期的版本中,底层的源码文件tensorflow/core/kernels/collective_ops.cc

当col_params_.instance.shape.num_elements() == 0时表明是首次批来的时候,记住了output_shape,当第二批次或后面的批次再来的时候,强行判断和首次记住的形状保持一致,如果不一致就报错打印出了上面的“输出形状不连续的问题”,即errors::Internal里的内容。这也就是之所以报梯度形状不一致的根本原因。

3、问题解决

          分析清楚了原因,制定对应的解决办法。当然可以将该段代码的逻辑去掉,当后面批次再来的时候,不做判断,而是让col_params_.instance.shape=output_shape始终跟最新你的输出保持一致。如下所示:去掉了老版本里的if else的判断,直接让col_params_.instance.shape=out_shape,兼容输出可变化的动态形状

该种解决方案优点是从根本上解决,上层应用无感知,然而缺点是改完后要重新编译cc代码生成so文件,或者升级到最新的版本,对于不开放的网络环境,升级tensorflow或重新编译成本巨高。为了依旧使用老版本,尽可能不动底层,采用修改上层的方法,虽然繁琐一些,但是修改成本会低恨多。

        对该问题进行更深入的分析,到底为什么梯度输出的形状会发生改变,即out_shape和首次批的输出形状可能会不一致呢?仔细梳理了每个批次的梯度产生过程。以一个id类特征product_id为例,假定训练的batch_size=1024,product总个数30,embedding后总参数大小[30,8], 第一个批次输入的批次数据是[1024,8],即每次输入1024个样本,而梯度回传有时候是[20,8],有时候是[16,8],之所以和输入没有对齐,经分析发现在反向传播,通过collective_all_gather收集各个集群上的梯度时,是以特征变量为单位,不是以样本量来衡量的,即会收集每个训练的特征变量,在各个节点上的梯度,收回来做累加或其他聚合操作,在有些1024的批data中,product_id包含20个(不同样本的product_id会有重复),有些1024的批data中,product_id包含16个,这样反回来的梯度是这20个或16个product_id的emb的梯度,所以看到的梯度的形状是[20,8]和[16,8],这也体现了训练的过程更新的可训练特征变量,以不同特征变量的个数来组织梯度也顺利成章。

解决办法:当一个批次的数据大小是batch_size的时候,根据以上分析,某个特征变量的不同值的个数上限是batch_size个,因此把梯度的形状pad成[batch_size,dim],这样就就保证了每次进入collective_op.all_gather的形状保持了一致,另一个问题就是这个batch_size如何传入cross_device_utils.py,刚上来考虑通过获取tensorgrah中输入变量的第一维的值来作为batch_size,这样会有个问题就是,当最后batch的大小不够一个batch_size的时候,补的形状就和前面的又一样,还是会失败,训练最后一步挂掉;因此考虑传入固定的静态手工配置的batch_size,通过参数传递的方式,内部经过的链路很长,会进入不同的模块,才会传导到cross_device_utils.py,这种方式改动太大,自然而然想到共享内存,python的共享内存可能需要第三方的工具包,成本也高,进而考虑共享文件,启动的时候将静态固定的batch_size写入一个固定的目录文件,在cross_device_utils.py里用到的时候读取文件,这样改动的成本还是有些繁琐,最后考虑python夸文件的变量共享,在cross_device_utils.py,定义一个全局变量global_batch_size给定默认值256,在训练启动的python文件main方法里通过引用修改该变量,即:

from tensorflow.python.distribute import cross_device_utils

cross_device_utils.global_batch_size=Configs[‘batch_size’]

具体修改如下:

修改前:

修改后:

4、总结

        本文对Ring-AllReduce通信框架下分布式训练梯度收集形状不一致的问题进行了分析,并阐述了从最底层和偏上层的不同解决思路。对使用稍早版本的tf搭建分布式训练平台有一定的借鉴作用。

这篇关于分布式训练同步梯度出现形状不一致的解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142878

相关文章

线上Java OOM问题定位与解决方案超详细解析

《线上JavaOOM问题定位与解决方案超详细解析》OOM是JVM抛出的错误,表示内存分配失败,:本文主要介绍线上JavaOOM问题定位与解决方案的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一、OOM问题核心认知1.1 OOM定义与技术定位1.2 OOM常见类型及技术特征二、OOM问题定位工具

Python一次性将指定版本所有包上传PyPI镜像解决方案

《Python一次性将指定版本所有包上传PyPI镜像解决方案》本文主要介绍了一个安全、完整、可离线部署的解决方案,用于一次性准备指定Python版本的所有包,然后导出到内网环境,感兴趣的小伙伴可以跟随... 目录为什么需要这个方案完整解决方案1. 项目目录结构2. 创建智能下载脚本3. 创建包清单生成脚本4

java.sql.SQLTransientConnectionException连接超时异常原因及解决方案

《java.sql.SQLTransientConnectionException连接超时异常原因及解决方案》:本文主要介绍java.sql.SQLTransientConnectionExcep... 目录一、引言二、异常信息分析三、可能的原因3.1 连接池配置不合理3.2 数据库负载过高3.3 连接泄漏

Python与MySQL实现数据库实时同步的详细步骤

《Python与MySQL实现数据库实时同步的详细步骤》在日常开发中,数据同步是一项常见的需求,本篇文章将使用Python和MySQL来实现数据库实时同步,我们将围绕数据变更捕获、数据处理和数据写入这... 目录前言摘要概述:数据同步方案1. 基本思路2. mysql Binlog 简介实现步骤与代码示例1

C#文件复制异常:"未能找到文件"的解决方案与预防措施

《C#文件复制异常:未能找到文件的解决方案与预防措施》在C#开发中,文件操作是基础中的基础,但有时最基础的File.Copy()方法也会抛出令人困惑的异常,当targetFilePath设置为D:2... 目录一个看似简单的文件操作问题问题重现与错误分析错误代码示例错误信息根本原因分析全面解决方案1. 确保

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

Redis实现分布式锁全过程

《Redis实现分布式锁全过程》文章介绍Redis实现分布式锁的方法,包括使用SETNX和EXPIRE命令确保互斥性与防死锁,Redisson客户端提供的便捷接口,以及Redlock算法通过多节点共识... 目录Redis实现分布式锁1. 分布式锁的基本原理2. 使用 Redis 实现分布式锁2.1 获取锁

SpringBoot3匹配Mybatis3的错误与解决方案

《SpringBoot3匹配Mybatis3的错误与解决方案》文章指出SpringBoot3与MyBatis3兼容性问题,因未更新MyBatis-Plus依赖至SpringBoot3专用坐标,导致类冲... 目录SpringBoot3匹配MyBATis3的错误与解决mybatis在SpringBoot3如果

C++ vector越界问题的完整解决方案

《C++vector越界问题的完整解决方案》在C++开发中,std::vector作为最常用的动态数组容器,其便捷性与性能优势使其成为处理可变长度数据的首选,然而,数组越界访问始终是威胁程序稳定性的... 目录引言一、vector越界的底层原理与危害1.1 越界访问的本质原因1.2 越界访问的实际危害二、基

Python 字符串裁切与提取全面且实用的解决方案

《Python字符串裁切与提取全面且实用的解决方案》本文梳理了Python字符串处理方法,涵盖基础切片、split/partition分割、正则匹配及结构化数据解析(如BeautifulSoup、j... 目录python 字符串裁切与提取的完整指南 基础切片方法1. 使用切片操作符[start:end]2