大模型微调中的内存效率问题及解决方案

2024-09-02 14:04

本文主要是介绍大模型微调中的内存效率问题及解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

大模型(LLMs)在大规模训练中的内存消耗问题日益凸显,传统的参数高效微调技术,如低秩适应(LoRA),虽然在一定程度上缓解了这一问题,但其性能在很多大规模微调场景下仍无法与全参数训练相媲美。

为了解决上述问题,香港科技大学以及伊利诺伊大学香槟分校的研究团队共同提出了一种新的训练策略——Layerwise Importance Sampled AdamW(LISA)。LISA策略基于对LoRA在微调任务中层级权重规范分布的观察,发现不同层的权重规范呈现出不寻常的偏斜分布。利用这一关键发现,研究者们提出了一种简单有效的训练方法,该方法在多种设置下的性能都超过了LoRA和全参数训练,同时内存成本与LoRA相当。图1为在Alpaca GPT-4数据集上,使用全参数训练(FT)、LoRA和LISA方法对LLaMA-2-7B模型进行训练时的损失变化情况。显示了LISA方法相比其他方法在训练损失上的优势。

论文链接:https://arxiv.org/abs/2403.17919

开源地址:https://github.com/OptimalScale/LMFlow

方法

为了理解LoRA如何仅用少量参数实现有效训练,研究者们对多个模型进行了实证研究,特别关注了不同层的权重规范。他们使用Alpaca-GPT4数据集进行微调,并在训练过程中详细记录了每一层ℓ在每次更新后的平均权重规范,其公式表示为:

其中表示层ℓ 的平均权重规范。

实验发现在LoRA训练中,嵌入层或语言模型(LM)头部层的权重规范显著大于中间层,有时甚至高出数百倍。然而,在全参数训练设置下,这种现象并不明显。

Figure 2 展示了GPT2和LLaMA-2-7B模型在LoRA和全参数训练期间的层级权重规范。图中x轴代表从嵌入权重到最终层的层级,y轴量化了权重规范。这一可视化揭示了一个关键趋势:嵌入层或LM头部层在LoRA中的权重规范远大于中间层。

基于上述发现,研究者们希望模拟LoRA的更新模式,通过采样不同的层进行冻结,以避免LoRA固有的低秩表示能力的限制,并模仿其快速学习过程。在全参数设置中,LoRA中权重规范较小的层也应该有较小的采样概率来解冻,以保持迭代中的预期学习率相同。这正是重要性采样的思想。

Algorithm 1 展示了LISA方法的步骤。在实践中,除了底部和顶部层外,LoRA中所有层的权重规范都较小,因此研究者们采用​=,其中γ控制优化过程中预期的解冻层数。γ作为一个补偿因子,用来桥接LoRA和全参数调优之间的差异,让LISA模拟与LoRA相似的层级更新模式。为了进一步控制实际设置中的内存消耗,研究者们每次随机采样γ层,以限制训练期间最大未冻结层数。

通过这种方法,LISA算法能够在保持内存效率的同时,提高大型语言模型微调的性能。这一创新方法为解决LoRA在大规模微调中的局限性提供了新的思路,并展示了在不同领域任务中应用的潜力。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
 

LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

为了证明LISA的内存效率,研究者们进行了峰值GPU内存消耗的实验。实验设置中,他们从Alpaca数据集(Taori et al., 2023)中随机抽取提示,并限制最大输出令牌长度为1024。重点关注两个关键超参数:LoRA的秩和LISA的激活层数。在其他超参数设置中,所有模型统一使用1的mini-batch大小,并排除了其他节省GPU内存的技术,如梯度检查点(Chen et al., 2016)、卸载(Ren et al., 2021)和快速注意力(Dao et al., 2022; Dao, 2023)。

Table 1为不同模型架构和配置下的峰值GPU内存消耗。特别是,当LISA配置增强了嵌入层(E)和两个额外层(E+H+2L)时,在微调LLaMA-2-70B模型时,与LoRA方法相比,显示出了相当大的GPU内存使用减少。具体而言LISA E+H+2L配置将峰值GPU内存从LoRA Rank 128配置所需的79G降低到75G。这种效率提升不是孤立的事件;在不同模型架构上观察到系统性的内存使用减少,表明LISA激活层的方法在内存效率上具有固有优势。

Figure 3 展示了不同方法和批量大小为1的LLaMA2-7B的GPU内存消耗。注意,LISA的内存减少允许LLaMA-2-7B在单个RTX4090(24GB)GPU上进行训练,这使得即使在笔记本电脑上也能负担得起高质量的微调。特别是由于LISA不引入适配器带来的额外参数,因此其激活内存消耗比LoRA少得多。由于pytorch(Paszke et al., 2019)与deepspeed(Rasley et al., 2020)允许在反向传播前删除冗余激活,LISA的激活内存甚至略低于全参数训练。

Figure 4 展示了不同方法和批量大小为1的LLaMA-2-7B的单次迭代时间成本。LISA由于减少了内存占用,还带来了加速效果。如图4所示,与全参数训练相比,LISA提供了大约2.9倍的加速,与LoRA相比大约有1.5倍的加速,这部分是由于去除了适配器结构。值得注意的是,LoRA和LISA的内存占用减少都显著加快了前向传播的速度,强调了内存高效训练的重要性。

LISA在保持显著内存节省的同时,还能在微调设置中获得有竞争力的性能。为了证明LISA优于LoRA,研究者们在Alpaca GPT-4数据集(Taori et al., 2023)上评估了它们的性能,该数据集包含由GPT-4(OpenAI et al., 2023)生成的52k对对话。微调的有效性在多个基准上进行评估:MT-Bench(Zheng et al., 2023)包含80个高质量的多轮问题,旨在从多个方面评估LLMs;MMLU(Hendrycks et al., 2020)总共包含57个任务,14,079个问题,涵盖广泛的世界知识;AGIEval(Zhong et al., 2023)作为以人为本的通用能力基准,包含9,316个实例;WinoGrande(Sakaguchi et al., 2021)是大规模常识推理数据集,包含44,000个实例,旨在挑战模型对上下文和常识知识的了解。

Table 2 和 Table 3 展示了中等规模LLMs的详细比较。基线包括全参数训练(FT)、低秩适应(LoRA)(Hu et al., 2022)和梯度低秩投影(GaLore)(Zhao et al., 2024)。结果表明,LISA在大多数评估轨道上一致性地优于其他微调方法,表明其在多样化任务和模型架构中的鲁棒性和有效性。LISA特别适用于指令跟随任务,在与其他基线方法相比时观察到较大的差距。LISA甚至超越了全参数训练,这表明当限制未冻结层数时,存在隐式正则化效果,类似于dropout(Srivastava et al., 2014)。

持续预训练对于使模型适应新数据和领域至关重要。为了评估LISA在持续预训练场景中的有效性,研究者们在数学领域与全参数训练进行了比较。

研究者们采用数学语料库OpenWebMath(Paster et al., 2023)构建持续预训练数据集。具体来说,他们从中提取了一个包含15亿令牌的高质量子集。详细情况在附录B.2中解释。在持续预训练后,然后对GSM8K(Cobbe et al., 2021)训练集进行相同的微调程序,该训练集包含7473个实例。

Table 4 显示,LISA能够实现与全参数训练相当甚至更好的性能,同时内存消耗要少得多。具体来说,与全参数训练相比,LISA只需要一半的内存成本。这表明LISA在计算效率和模型性能之间实现了更好的平衡。根据研究者的经验,将未冻结层数减少到原始大小的一半,在持续预训练期间不会变差甚至表现更好,同时内存消耗要少得多。

为了进一步证明LISA在大规模LLMs上的可扩展性,研究者们在LLaMA-2-70B(Touvron et al., 2023b)上进行了额外的微调实验。除了前面提到的指令跟随任务外,研究者们还使用了额外的特定领域微调任务,包括数学和医学QA基准。GSM8K数据集(Cobbe et al., 2021)包含7473个训练实例和1319个测试实例,用于数学领域。对于医学领域,研究者们选择了PubMedQA数据集(Jin et al., 2019),该数据集包含211.3K个人工生成的QA训练实例和1K个测试实例。

Table 5 显示,LISA在与LoRA相比时一致性地产生更好或相当的性能。此外,LISA在指令调整任务中再次超越了全参数训练,为LISA在大规模训练场景下的可扩展性提供了有力证据。

LISA的两个关键超参数是采样层数γ和采样周期K。为了直观和实证地指导这些超参数的选择,研究者们使用TinyLlama(Zhang et al., 2024)和LLaMA-2-7B(Touvron et al., 2023b)模型,在Alpaca-GPT4数据集上进行了消融研究。γ的配置,如E+H+2L、E+H+8L,分别表示为γ = 2和γ = 8。至于采样周期K = T /n,T = 122代表实验框架内的最大训练步骤。Table 6 中的发现表明,γ和K都显著影响LISA算法的性能。具体为较高的γ值增加了可训练参数的数量,尽管内存成本更高。另一方面,最优的K值促进了更频繁的层切换,从而在一定阈值内提高了性能,超出该阈值后性能可能会恶化。通常的经验法则是:更多的采样层和更高的采样周期会带来更好的性能。

由于LISA在算法上依赖于层的采样序列,研究者们进一步研究了LISA在三个不同运行中性能的变化,每个运行都使用不同的随机种子进行层选择。研究者们采用TinyLlama、LLaMA2-7B和Mistral-7B模型与Alpaca-GPT4数据集,同时保持所有其他超参数与前面指令跟随实验中使用的一致。Table 7 显示,LISA对不同的随机种子相当有韧性,三次运行之间的性能差距在0.13以内,与超过基线方法的性能增益相比,这是一个小值。

实验结果显示,LISA在保持相似或更低的GPU内存消耗的同时,在下游微调任务中的性能超越了LoRA,甚至在某些情况下还超越了全参数训练。

这篇关于大模型微调中的内存效率问题及解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1130164

相关文章

SpringBoot启动报错的11个高频问题排查与解决终极指南

《SpringBoot启动报错的11个高频问题排查与解决终极指南》这篇文章主要为大家详细介绍了SpringBoot启动报错的11个高频问题的排查与解决,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一... 目录1. 依赖冲突:NoSuchMethodError 的终极解法2. Bean注入失败:No qu

找不到Anaconda prompt终端的原因分析及解决方案

《找不到Anacondaprompt终端的原因分析及解决方案》因为anaconda还没有初始化,在安装anaconda的过程中,有一行是否要添加anaconda到菜单目录中,由于没有勾选,导致没有菜... 目录问题原因问http://www.chinasem.cn题解决安装了 Anaconda 却找不到 An

Spring定时任务只执行一次的原因分析与解决方案

《Spring定时任务只执行一次的原因分析与解决方案》在使用Spring的@Scheduled定时任务时,你是否遇到过任务只执行一次,后续不再触发的情况?这种情况可能由多种原因导致,如未启用调度、线程... 目录1. 问题背景2. Spring定时任务的基本用法3. 为什么定时任务只执行一次?3.1 未启用

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

MySQL新增字段后Java实体未更新的潜在问题与解决方案

《MySQL新增字段后Java实体未更新的潜在问题与解决方案》在Java+MySQL的开发中,我们通常使用ORM框架来映射数据库表与Java对象,但有时候,数据库表结构变更(如新增字段)后,开发人员可... 目录引言1. 问题背景:数据库与 Java 实体不同步1.1 常见场景1.2 示例代码2. 不同操作

如何解决mysql出现Incorrect string value for column ‘表项‘ at row 1错误问题

《如何解决mysql出现Incorrectstringvalueforcolumn‘表项‘atrow1错误问题》:本文主要介绍如何解决mysql出现Incorrectstringv... 目录mysql出现Incorrect string value for column ‘表项‘ at row 1错误报错

如何解决Spring MVC中响应乱码问题

《如何解决SpringMVC中响应乱码问题》:本文主要介绍如何解决SpringMVC中响应乱码问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring MVC最新响应中乱码解决方式以前的解决办法这是比较通用的一种方法总结Spring MVC最新响应中乱码解

java常见报错及解决方案总结

《java常见报错及解决方案总结》:本文主要介绍Java编程中常见错误类型及示例,包括语法错误、空指针异常、数组下标越界、类型转换异常、文件未找到异常、除以零异常、非法线程操作异常、方法未定义异常... 目录1. 语法错误 (Syntax Errors)示例 1:解决方案:2. 空指针异常 (NullPoi

pip无法安装osgeo失败的问题解决

《pip无法安装osgeo失败的问题解决》本文主要介绍了pip无法安装osgeo失败的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 进入官方提供的扩展包下载网站寻找版本适配的whl文件注意:要选择cp(python版本)和你py

使用DrissionPage控制360浏览器的完美解决方案

《使用DrissionPage控制360浏览器的完美解决方案》在网页自动化领域,经常遇到需要保持登录状态、保留Cookie等场景,今天要分享的方案可以完美解决这个问题:使用DrissionPage直接... 目录完整代码引言为什么要使用已有用户数据?核心代码实现1. 导入必要模块2. 关键配置(重点!)3.