MOELoRA —— 多任务医学应用中的参数高效微调方法

2024-09-01 22:44

本文主要是介绍MOELoRA —— 多任务医学应用中的参数高效微调方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

在医疗场景中,LLMs可以应用于多种不同的任务,如医生推荐、诊断预测、药物推荐、医学实体识别、临床报告生成等。这些任务的输入和输出差异很大,给统一模型的微调带来了挑战。而且LLMs的参数众多,导致微调过程中时间和计算资源的消耗巨大。针对这些问题,来自西安交通大学、香港城市大学、腾讯YouTu Lab等机构的研究者们提出了一种新颖的参数高效微调框架——MOELoRA。它结合了多任务学习和参数高效微调的优点,通过设计多个专家(Experts)作为可训练参数,每个专家由一对低秩矩阵组成,以保持可训练参数的数量较小。研究者还提出了一种任务驱动的门控函数,用于调节每个专家的贡献,并为不同任务生成不同的参数。

论文链接:https://arxiv.org/pdf/2310.18339

项目链接: https://github.com/liuqidong07/MOELoRA-peft

方法

图3 为使用MOELoRA进行LLMs参数高效微调和推理过程。在参数高效微调领域,LoRA方法引入了仅训练两个低秩矩阵来替代密集层更新的概念。基于此本方法将MOELoRA层集成到每个密集层中,使它们能够获取键、查询和值,同时促进前馈网络(FNN)的运作。图3中以FNN为例进行说明。该方法的一个显著优势是,研究者只为不同任务微调MOELoRA层的参数,而保持原始LLMs的其他参数不变。另外每个MOELoRA层包含多个专家,这些专家旨在捕获不同医学任务的多样化知识。研究者引入了一个任务驱动的门控函数,以确保为每个任务学习到独特的参数集。这个函数决定了所有MOELoRA层中专家的贡献权重,从而生成针对不同任务量身定制的独特更新参数。研究者为所有MOELoRA层使用单个门控函数,而不是让门控函数与MOELoRA层一一对应。在微调过程中,研究者更新来自所有任务混合数据的MOELoRA层。然后,在推理过程中,MOELoRA可以为每个任务派生出不同的微调权重。

LoRA方法在LLMs的微调中展示了其有效性和效率。它受到低内在维度特性的启发,将LLMs中的参数微调过程重新定义为低秩分解。具体而言方程式W_0​+ΔW=W+BA捕捉了这种分解。这里,​代表预训练LLMs的参数矩阵,而​表示在微调过程中更新的矩阵。矩阵B∈是低秩且可训练的。给定这样的设置,与LoRA层配对的线性层的前向过程可以表示为:

其中,x代表维度为d_in​的输入向量,ℎ是维度为d_out​的输出向量。可训练低秩矩阵的秩由r表示,它决定了可训练参数的数量。常数超参数α促进了秩r的调整。在LoRA微调过程中,LLMs中的所有参数保持不变。只有低秩矩阵A和B会进行微调。鉴于r≪d_in​且r≪d_out​,A和B中的参数总数比W_0​中的要少得多。这样的特性使得微调过程实现了参数效率。然而,原始LoRA中所有任务的集成参数微调会导致学习医学知识各个方面的困难。一个潜在的解决方案是将整个参数集分割成几个部分,并为各种任务得出不同的组合。专家混合模型(MOE)建议使用多个专家网络来捕获多任务信息的不同方面,这与组合概念相符。这一洞见引导研究者设计了MOELoRA,它无缝集成了LoRA和MOE的优势。为了协调LoRA和MOE的不同前向过程,研究者引入了一组专家来学习更新矩阵ΔW。由于MOELoRA使用来自所有任务的数据对专家进行微调,它内在地捕获了共享任务知识。为了保持紧凑的参数大小,MOELoRA层中的每个专家都构建为两个分解的低秩矩阵。基于这种结构,对于来自任务T_j​的样本,与MOELoRA层配对的线性层的前向过程表示为:

其中,h_j​和x_j​代表来自T_j​的中间LLM层的输入和输出。矩阵形成专家E_i​。超参数N表示MOELoRA中的专家数量,对于每个专家,矩阵A和B的秩是r/N​。在方程(4)中,术语调节这些贡献权重,用于任务T_j​。这个权重由研究者提出的门控函数确定。这里,研究者将讨论LoRA和MOELoRA的可训练参数数量。就LoRA而言,两个低秩矩阵​包含所有可训练参数。因此,LoRA的可训练参数数量是​=。至于MOELoRA,有N个可训练专家,每个专家拥有,所以总数计算为。总之,MOELoRA具有与LoRA相同数量的可训练参数,这表明了高效率。

如前所述,每个专家的贡献应该针对特定任务进行定制。为了调节这些贡献,研究者引入了一个门控函数。由于这些权重本质上是任务特定的,研究者的门控函数被设计为将任务身份作为输入。研究者采用了一个任务嵌入矩阵,记为​,其中d_T​代表任务嵌入的维度。确定任务T_j​后,研究者提取E的第j列,作为该任务的表示向量,记为​。为了确定任务T_j​的贡献权重,研究者应用线性变换。这一计算被以下方程捕获:

这里,代表为任务T_j​量身定制的贡献权重向量。变换矩阵记为。为了防止权重过大,研究者采用softmax操作来归一化贡献权重。图3中提到的门控自然是一个密集设计,以结合所有专家。研究者还设计了一个稀疏版本的任务驱动门控,以探索哪种设计更有效。设计的稀疏门控如下公式:

与传统的MOE设计直接将输入向量x输入门控函数不同,研究者的方法不同。研究者仅将任务身份输入门控函数,如图3所示,旨在为每个任务产生一组独特的模型参数。例如,如果某人希望恢复任务T_j​的微调参数,则该过程可以表述为:

如果门控函数由输入向量x驱动,权重向量将因样本而异。这意味着每个样本将拥有其独特的ω_j​,导致特定于样本的微调参数矩阵。这种设计将使参数无法按任务恢复。能够为每个任务恢复参数提供了两个主要优势:

1) 任务定制:每个任务都使用一组参数进行微调,这有助于学习更多任务特定的信息并缓解数据不平衡问题。

2) 推理效率:恢复的微调LLMs表现出降低的推理延迟。这归因于消除了与MOELoRA层相关的额外前向计算的需要。

研究者也在算法1中总结了整个过程:

微调:研究者首先根据LLMs中指定的层和几个超参数配置MOELoRA(第1-3行)。然后,对于参数高效微调,所有预训练的LLMs中的参数(第4行)都被冻结。在微调过程中,研究者迭代地从所有任务中随机抽取一批数据,而不是像一些多任务研究那样将来自同一任务的样本分到一个批次中。研究者通过实验中的性能比较选择了随机抽样批次。使用这批数据,研究者可以进行前向过程并计算微调的损失(第6-7行)。对于参数更新,研究者只微调MOELoRA和任务驱动门控函数的参数,即

推理:MOELoRA可以通过方程(8)为每个任务恢复微调的参数矩阵。对于推理,首先恢复每个任务的微调参数(第10-13行),这表明每个任务都有自己的LLMs参数。可以应用相应的LLMs来完成指定的任务。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

研究者在PromptCBLUE数据集上进行实验,这是一个多任务中文医疗数据集,在天池竞赛平台上提供。该数据集包含16个不同的医疗任务,每个任务都使用特定的提示转换为纯文本格式,确保与LLMs兼容。由于计算限制,研究者随机选择了8个任务进行实验。在预处理中,他们从原始数据集中删除了重复样本。由于竞赛中使用的测试集尚未发布,研究者选择使用开发集作为测试集。然后,实验的验证集从竞赛的训练集中得出,其大小与测试集匹配。数据集的统计信息在表1中总结。

研究者将MOELoRA与四组不同的基线进行比较:

  • 未微调的LLMs:使用In-Context Learning来指导LLMs完成任务。
  • 微调的LLMs:包括P-Tuning、LoRA (Full)、LoRA (Single)和LoRA (Full+TP)等策略。
  • 模型编辑:Task-Arithmetic方法。
  • 跨任务泛化:评估LoRAHub和MoLoRA方法。

实验使用PyTorch 1.12.0和Python 3.9.5进行模拟,代码在Tesla V100 32G GPU上运行以加速。ChatGLM-6B作为微调的基础模型。对于所有LoRA微调基线和提出的MOELoRA,指定了可训练层。输入和输出长度分别配置为1,024和196。批量大小设置为64,最多8,000个训练步骤。LoRA的秩𝑟固定为16,LoRA dropout 𝛼 = 0.1。对于MOELoRA,专家数量设置为8。

研究者采用多种指标来评估每个任务的性质。例如,CMeIE任务使用Micro-F1,而CHIP-CTC和KUAKE-QIC任务使用Macro-F1。对于文本生成任务,如IMCS-V2-MRG和MedDG,应用Rouge-L。所有任务的平均分数用于评估整体性能。

表2展示了MOELoRA与竞争基线的整体实验结果。MOELoRA(D)和MOELoRA(S)分别代表MOELoRA的密集和稀疏门控设计。分析所有任务的平均分数,MOELoRA(D)在所有方法中表现最佳:

  • 未微调的LLMs:明显落后于其他组,突出了微调LLMs以融入特定任务医学知识的重要性。
  • 参数高效微调策略:LoRA基础方法明显优于P-Tuning。LoRA (Full)和LoRA (Full+TP)都利用所有任务的数据,但LoRA (Full+TP)略逊一筹,可能归因于任务提示的添加,导致输入文本的扩展,可能由于输入长度限制而截断信息词。
  • 模型编辑:Task-Arithmetic明显落后于所有微调竞争对手。
  • 跨任务泛化:尽管在跨任务泛化设置中表现令人印象深刻,但它们需要大量任务数据,这与多任务设置相冲突。

表3中展示了消融研究的结果。没有MOE架构的变体(即LoRA(Full))表现较差,强调了MOE架构的重要性。同样,没有门控功能的变体也落后于MOELoRA,突出了门控功能的有效性。多个门控功能的变体由于过度参数化而表现稍差。

为了回答RQ3,研究者探讨了超参数对MOELoRA(D)性能的影响。特别是,专家数量𝑁和LoRA秩𝑟的变化如何影响结果。发现随着𝑁从0增加到8,性能得到改善,但当𝑁增加到16时,性能略有下降。同时,增加𝑟可以提高性能,但也会导致可训练参数数量的增加。

为了评估训练和推理效率,研究者在图5中比较了可调参数的比例和推理延迟。MOELoRA在训练和推理效率方面与LoRA (Full)相当,通过训练不超过LLMs的0.48%参数来节省资源。MoLoRA和MOELoRA(M)需要更多的可训练参数,因为它们为每个可训练的低秩层设置了额外的门控。在推理方面,所有模型都需要相同的推理延迟,除了MoLoRA,因为它无法像方程(8)那样恢复微调参数,所以需要在推理时伴随MoLoRA层,导致额外的前向计算引起的更多推理延迟。

为了回答RQ5,图6中展示了四个任务的专家权重。每个任务中不同颜色的条形长度代表相应专家的权重。这表明不同专家在不同医疗任务中专门捕获特定方面的知识,强调了MOELoRA在利用共享知识以惠及相关任务方面的熟练程度。

实验结果表明,MOELoRA在性能上超越了现有的参数高效微调方法。这一研究成果不仅为医疗领域的LLMs应用提供了新的思路,也为其他领域的多任务学习提供了参考。

这篇关于MOELoRA —— 多任务医学应用中的参数高效微调方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1128253

相关文章

Window Server2016加入AD域的方法步骤

《WindowServer2016加入AD域的方法步骤》:本文主要介绍WindowServer2016加入AD域的方法步骤,包括配置DNS、检测ping通、更改计算机域、输入账号密码、重启服务... 目录一、 准备条件二、配置ServerB加入ServerA的AD域(test.ly)三、查看加入AD域后的变

Window Server2016 AD域的创建的方法步骤

《WindowServer2016AD域的创建的方法步骤》本文主要介绍了WindowServer2016AD域的创建的方法步骤,文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、准备条件二、在ServerA服务器中常见AD域管理器:三、创建AD域,域地址为“test.ly”

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

Java 字符数组转字符串的常用方法

《Java字符数组转字符串的常用方法》文章总结了在Java中将字符数组转换为字符串的几种常用方法,包括使用String构造函数、String.valueOf()方法、StringBuilder以及A... 目录1. 使用String构造函数1.1 基本转换方法1.2 注意事项2. 使用String.valu

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

Python中使用defaultdict和Counter的方法

《Python中使用defaultdict和Counter的方法》本文深入探讨了Python中的两个强大工具——defaultdict和Counter,并详细介绍了它们的工作原理、应用场景以及在实际编... 目录引言defaultdict的深入应用什么是defaultdictdefaultdict的工作原理

使用Python进行文件读写操作的基本方法

《使用Python进行文件读写操作的基本方法》今天的内容来介绍Python中进行文件读写操作的方法,这在学习Python时是必不可少的技术点,希望可以帮助到正在学习python的小伙伴,以下是Pyth... 目录一、文件读取:二、文件写入:三、文件追加:四、文件读写的二进制模式:五、使用 json 模块读写

Python如何使用seleniumwire接管Chrome查看控制台中参数

《Python如何使用seleniumwire接管Chrome查看控制台中参数》文章介绍了如何使用Python的seleniumwire库来接管Chrome浏览器,并通过控制台查看接口参数,本文给大家... 1、cmd打开控制台,启动谷歌并制定端口号,找不到文件的加环境变量chrome.exe --rem

Oracle数据库使用 listagg去重删除重复数据的方法汇总

《Oracle数据库使用listagg去重删除重复数据的方法汇总》文章介绍了在Oracle数据库中使用LISTAGG和XMLAGG函数进行字符串聚合并去重的方法,包括去重聚合、使用XML解析和CLO... 目录案例表第一种:使用wm_concat() + distinct去重聚合第二种:使用listagg,