大模型微调中的内存效率问题及解决方案

2024-09-02 14:04

本文主要是介绍大模型微调中的内存效率问题及解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

大模型(LLMs)在大规模训练中的内存消耗问题日益凸显,传统的参数高效微调技术,如低秩适应(LoRA),虽然在一定程度上缓解了这一问题,但其性能在很多大规模微调场景下仍无法与全参数训练相媲美。

为了解决上述问题,香港科技大学以及伊利诺伊大学香槟分校的研究团队共同提出了一种新的训练策略——Layerwise Importance Sampled AdamW(LISA)。LISA策略基于对LoRA在微调任务中层级权重规范分布的观察,发现不同层的权重规范呈现出不寻常的偏斜分布。利用这一关键发现,研究者们提出了一种简单有效的训练方法,该方法在多种设置下的性能都超过了LoRA和全参数训练,同时内存成本与LoRA相当。图1为在Alpaca GPT-4数据集上,使用全参数训练(FT)、LoRA和LISA方法对LLaMA-2-7B模型进行训练时的损失变化情况。显示了LISA方法相比其他方法在训练损失上的优势。

论文链接:https://arxiv.org/abs/2403.17919

开源地址:https://github.com/OptimalScale/LMFlow

方法

为了理解LoRA如何仅用少量参数实现有效训练,研究者们对多个模型进行了实证研究,特别关注了不同层的权重规范。他们使用Alpaca-GPT4数据集进行微调,并在训练过程中详细记录了每一层ℓ在每次更新后的平均权重规范,其公式表示为:

其中表示层ℓ 的平均权重规范。

实验发现在LoRA训练中,嵌入层或语言模型(LM)头部层的权重规范显著大于中间层,有时甚至高出数百倍。然而,在全参数训练设置下,这种现象并不明显。

Figure 2 展示了GPT2和LLaMA-2-7B模型在LoRA和全参数训练期间的层级权重规范。图中x轴代表从嵌入权重到最终层的层级,y轴量化了权重规范。这一可视化揭示了一个关键趋势:嵌入层或LM头部层在LoRA中的权重规范远大于中间层。

基于上述发现,研究者们希望模拟LoRA的更新模式,通过采样不同的层进行冻结,以避免LoRA固有的低秩表示能力的限制,并模仿其快速学习过程。在全参数设置中,LoRA中权重规范较小的层也应该有较小的采样概率来解冻,以保持迭代中的预期学习率相同。这正是重要性采样的思想。

Algorithm 1 展示了LISA方法的步骤。在实践中,除了底部和顶部层外,LoRA中所有层的权重规范都较小,因此研究者们采用​=,其中γ控制优化过程中预期的解冻层数。γ作为一个补偿因子,用来桥接LoRA和全参数调优之间的差异,让LISA模拟与LoRA相似的层级更新模式。为了进一步控制实际设置中的内存消耗,研究者们每次随机采样γ层,以限制训练期间最大未冻结层数。

通过这种方法,LISA算法能够在保持内存效率的同时,提高大型语言模型微调的性能。这一创新方法为解决LoRA在大规模微调中的局限性提供了新的思路,并展示了在不同领域任务中应用的潜力。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
 

LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

为了证明LISA的内存效率,研究者们进行了峰值GPU内存消耗的实验。实验设置中,他们从Alpaca数据集(Taori et al., 2023)中随机抽取提示,并限制最大输出令牌长度为1024。重点关注两个关键超参数:LoRA的秩和LISA的激活层数。在其他超参数设置中,所有模型统一使用1的mini-batch大小,并排除了其他节省GPU内存的技术,如梯度检查点(Chen et al., 2016)、卸载(Ren et al., 2021)和快速注意力(Dao et al., 2022; Dao, 2023)。

Table 1为不同模型架构和配置下的峰值GPU内存消耗。特别是,当LISA配置增强了嵌入层(E)和两个额外层(E+H+2L)时,在微调LLaMA-2-70B模型时,与LoRA方法相比,显示出了相当大的GPU内存使用减少。具体而言LISA E+H+2L配置将峰值GPU内存从LoRA Rank 128配置所需的79G降低到75G。这种效率提升不是孤立的事件;在不同模型架构上观察到系统性的内存使用减少,表明LISA激活层的方法在内存效率上具有固有优势。

Figure 3 展示了不同方法和批量大小为1的LLaMA2-7B的GPU内存消耗。注意,LISA的内存减少允许LLaMA-2-7B在单个RTX4090(24GB)GPU上进行训练,这使得即使在笔记本电脑上也能负担得起高质量的微调。特别是由于LISA不引入适配器带来的额外参数,因此其激活内存消耗比LoRA少得多。由于pytorch(Paszke et al., 2019)与deepspeed(Rasley et al., 2020)允许在反向传播前删除冗余激活,LISA的激活内存甚至略低于全参数训练。

Figure 4 展示了不同方法和批量大小为1的LLaMA-2-7B的单次迭代时间成本。LISA由于减少了内存占用,还带来了加速效果。如图4所示,与全参数训练相比,LISA提供了大约2.9倍的加速,与LoRA相比大约有1.5倍的加速,这部分是由于去除了适配器结构。值得注意的是,LoRA和LISA的内存占用减少都显著加快了前向传播的速度,强调了内存高效训练的重要性。

LISA在保持显著内存节省的同时,还能在微调设置中获得有竞争力的性能。为了证明LISA优于LoRA,研究者们在Alpaca GPT-4数据集(Taori et al., 2023)上评估了它们的性能,该数据集包含由GPT-4(OpenAI et al., 2023)生成的52k对对话。微调的有效性在多个基准上进行评估:MT-Bench(Zheng et al., 2023)包含80个高质量的多轮问题,旨在从多个方面评估LLMs;MMLU(Hendrycks et al., 2020)总共包含57个任务,14,079个问题,涵盖广泛的世界知识;AGIEval(Zhong et al., 2023)作为以人为本的通用能力基准,包含9,316个实例;WinoGrande(Sakaguchi et al., 2021)是大规模常识推理数据集,包含44,000个实例,旨在挑战模型对上下文和常识知识的了解。

Table 2 和 Table 3 展示了中等规模LLMs的详细比较。基线包括全参数训练(FT)、低秩适应(LoRA)(Hu et al., 2022)和梯度低秩投影(GaLore)(Zhao et al., 2024)。结果表明,LISA在大多数评估轨道上一致性地优于其他微调方法,表明其在多样化任务和模型架构中的鲁棒性和有效性。LISA特别适用于指令跟随任务,在与其他基线方法相比时观察到较大的差距。LISA甚至超越了全参数训练,这表明当限制未冻结层数时,存在隐式正则化效果,类似于dropout(Srivastava et al., 2014)。

持续预训练对于使模型适应新数据和领域至关重要。为了评估LISA在持续预训练场景中的有效性,研究者们在数学领域与全参数训练进行了比较。

研究者们采用数学语料库OpenWebMath(Paster et al., 2023)构建持续预训练数据集。具体来说,他们从中提取了一个包含15亿令牌的高质量子集。详细情况在附录B.2中解释。在持续预训练后,然后对GSM8K(Cobbe et al., 2021)训练集进行相同的微调程序,该训练集包含7473个实例。

Table 4 显示,LISA能够实现与全参数训练相当甚至更好的性能,同时内存消耗要少得多。具体来说,与全参数训练相比,LISA只需要一半的内存成本。这表明LISA在计算效率和模型性能之间实现了更好的平衡。根据研究者的经验,将未冻结层数减少到原始大小的一半,在持续预训练期间不会变差甚至表现更好,同时内存消耗要少得多。

为了进一步证明LISA在大规模LLMs上的可扩展性,研究者们在LLaMA-2-70B(Touvron et al., 2023b)上进行了额外的微调实验。除了前面提到的指令跟随任务外,研究者们还使用了额外的特定领域微调任务,包括数学和医学QA基准。GSM8K数据集(Cobbe et al., 2021)包含7473个训练实例和1319个测试实例,用于数学领域。对于医学领域,研究者们选择了PubMedQA数据集(Jin et al., 2019),该数据集包含211.3K个人工生成的QA训练实例和1K个测试实例。

Table 5 显示,LISA在与LoRA相比时一致性地产生更好或相当的性能。此外,LISA在指令调整任务中再次超越了全参数训练,为LISA在大规模训练场景下的可扩展性提供了有力证据。

LISA的两个关键超参数是采样层数γ和采样周期K。为了直观和实证地指导这些超参数的选择,研究者们使用TinyLlama(Zhang et al., 2024)和LLaMA-2-7B(Touvron et al., 2023b)模型,在Alpaca-GPT4数据集上进行了消融研究。γ的配置,如E+H+2L、E+H+8L,分别表示为γ = 2和γ = 8。至于采样周期K = T /n,T = 122代表实验框架内的最大训练步骤。Table 6 中的发现表明,γ和K都显著影响LISA算法的性能。具体为较高的γ值增加了可训练参数的数量,尽管内存成本更高。另一方面,最优的K值促进了更频繁的层切换,从而在一定阈值内提高了性能,超出该阈值后性能可能会恶化。通常的经验法则是:更多的采样层和更高的采样周期会带来更好的性能。

由于LISA在算法上依赖于层的采样序列,研究者们进一步研究了LISA在三个不同运行中性能的变化,每个运行都使用不同的随机种子进行层选择。研究者们采用TinyLlama、LLaMA2-7B和Mistral-7B模型与Alpaca-GPT4数据集,同时保持所有其他超参数与前面指令跟随实验中使用的一致。Table 7 显示,LISA对不同的随机种子相当有韧性,三次运行之间的性能差距在0.13以内,与超过基线方法的性能增益相比,这是一个小值。

实验结果显示,LISA在保持相似或更低的GPU内存消耗的同时,在下游微调任务中的性能超越了LoRA,甚至在某些情况下还超越了全参数训练。

这篇关于大模型微调中的内存效率问题及解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1130164

相关文章

MybatisGenerator文件生成不出对应文件的问题

《MybatisGenerator文件生成不出对应文件的问题》本文介绍了使用MybatisGenerator生成文件时遇到的问题及解决方法,主要步骤包括检查目标表是否存在、是否能连接到数据库、配置生成... 目录MyBATisGenerator 文件生成不出对应文件先在项目结构里引入“targetProje

C#使用HttpClient进行Post请求出现超时问题的解决及优化

《C#使用HttpClient进行Post请求出现超时问题的解决及优化》最近我的控制台程序发现有时候总是出现请求超时等问题,通常好几分钟最多只有3-4个请求,在使用apipost发现并发10个5分钟也... 目录优化结论单例HttpClient连接池耗尽和并发并发异步最终优化后优化结论我直接上优化结论吧,

Java内存泄漏问题的排查、优化与最佳实践

《Java内存泄漏问题的排查、优化与最佳实践》在Java开发中,内存泄漏是一个常见且令人头疼的问题,内存泄漏指的是程序在运行过程中,已经不再使用的对象没有被及时释放,从而导致内存占用不断增加,最终... 目录引言1. 什么是内存泄漏?常见的内存泄漏情况2. 如何排查 Java 中的内存泄漏?2.1 使用 J

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

numpy求解线性代数相关问题

《numpy求解线性代数相关问题》本文主要介绍了numpy求解线性代数相关问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 在numpy中有numpy.array类型和numpy.mat类型,前者是数组类型,后者是矩阵类型。数组

解决systemctl reload nginx重启Nginx服务报错:Job for nginx.service invalid问题

《解决systemctlreloadnginx重启Nginx服务报错:Jobfornginx.serviceinvalid问题》文章描述了通过`systemctlstatusnginx.se... 目录systemctl reload nginx重启Nginx服务报错:Job for nginx.javas

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1

深入理解Redis大key的危害及解决方案

《深入理解Redis大key的危害及解决方案》本文主要介绍了深入理解Redis大key的危害及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、背景二、什么是大key三、大key评价标准四、大key 产生的原因与场景五、大key影响与危

vue解决子组件样式覆盖问题scoped deep

《vue解决子组件样式覆盖问题scopeddeep》文章主要介绍了在Vue项目中处理全局样式和局部样式的方法,包括使用scoped属性和深度选择器(/deep/)来覆盖子组件的样式,作者建议所有组件... 目录前言scoped分析deep分析使用总结所有组件必须加scoped父组件覆盖子组件使用deep前言