LLaVA-MoLE:解决多模态大模型指令微调中的数据冲突问题

2024-09-02 08:12

本文主要是介绍LLaVA-MoLE:解决多模态大模型指令微调中的数据冲突问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

多模态大模型(MLLMs)通过指令微调(instruction finetuning),能够执行各种任务,如理解图表、处理文档和回答基于图像的问题。但是,当从不同领域混合指令数据进行微调时,模型在特定领域的任务上可能会出现性能下降。这种现象被称为数据冲突,它限制了通过增加新领域训练数据来扩展MLLM能力的可能性。为了应对这一挑战,来自美团公司的研究者们提出了一种新颖的方法——LLaVA-MoLE,即稀疏混合LoRA专家(Sparse Mixture of LoRA Experts)。

该模型基于LLaVA-1.5,通过在Transformer层中引入一组LoRA(Low-Rank Adaption)专家,并为每个token选择最适合的专家进行处理。这种设计允许模型根据不同领域的token激活不同的专家,从而扩展了MLLM处理多领域数据的能力。

论文链接:https://arxiv.org/pdf/2401.16160

方法

低秩适应(LoRA)是一种针对大模型(LLMs)的参数高效微调方法。它能够应用于任意线性层。具体来说,对于一个输入为 和权重矩阵 ​ 的线性层 h=Wx,LoRA 学习一个低秩分解的更新:

其中,是低秩矩阵,r 是远小于d_o​ 和 d_i​ 的秩,α 控制对原始W 的变化幅度。在学习LoRA模块过程中,只有矩阵A 和 B 会被更新。

图 2 展示了 LLaVA-MoLE 模型的整体框架,该模型基于 LLaVA-1.5 构建,采用了稀疏混合 LoRA 专家(Sparse Mixture of LoRA Experts)的方法来训练。

  1. 输入图像处理:输入图像首先通过 CLIP ViT(Vision Transformer)进行处理,CLIP ViT 是一种视觉编码器,能够将图像转换成一系列的视觉嵌入(visual embeddings)。之后,这些视觉嵌入通过一个两层的多层感知器(MLP)进行进一步的映射。

  2. 文本输入处理:文本输入首先被分词(tokenized),然后通过词嵌入矩阵转换成嵌入表示,这些嵌入与视觉输入一起被串联(concatenated),形成最终输入到大型语言模型(LLM)的混合嵌入序列。

  3. 稀疏混合 LoRA 专家:在 LLaVA-MoLE 模型中,每个 Transformer 层都采用了提出的稀疏混合 LoRA 专家进行训练。具体来说,每个全连接层(FFN)都会根据路由器(router)的输出分布选择并结合一个 LoRA 专家来进行计算。

  4. 路由器(Router)的作用:路由器负责为每个 token 分配一个最合适的 LoRA 专家。路由器的输出分布决定了 FFN 应该选择哪个专家来处理当前的 token。

  5. 自注意力(Self-Attention)训练:自注意力机制同样采用 LoRA 进行训练,但在这个框架中没有应用专家混合(MoE)。

  6. 计算并行化:对于每个 LoRA 专家,相同子序列的 token 可以并行计算,这提高了模型训练的效率。

如图 2 所示,一个MLLM可以被表述为:

其中 是视觉编码器和适配器,将输入图像映射成一系列视觉嵌入,将输入问题 T_q​ 进行标记化并用词嵌入矩阵嵌入离散标记,而 ∣∣ 是序列连接操作。因此,MLLM的输入实际上是一个混合嵌入序列。训练MLLM的指令数据被组织成三元组 (),不同的指令数据集可能有不同的分布,导致训练出的MLLM表现出不同的行为或专长。

为了缓解混合不同类型的指令数据时产生的冲突,研究者引入了一组LoRA专家和一个路由器。在每个输入token上,路由器学习选择最合适的专家激活,使模型具有额外的能力来处理不同类型的输入。假设每层有K 个专家,选择具有最高路由函数值的专家:

然后激活选定的专家来执行实际计算,而忽略当前token的其他专家。例如,对于现代LLMs中的FFN层通常是多层的,每一层的FFN都会有一个单独的MoE,但它们共享相同的路由器。通过只激活top-1专家,实际计算成本与原始FFN中的plain-LoRA大致相同。

为了确保模型的高效运行,研究者还引入了负载平衡损失,以避免专家分配的严重不平衡。负载平衡损失的公式为:

其中 cj​ 是分配给第j 个专家的token数量,pj​ 是第 j 个专家的总路由概率。通过最小化,专家的分配趋于均匀,从而避免了某些专家过载而另一些专家闲置的问题。

通过上述方法,LLaVA-MoLE模型能够有效地解决数据冲突问题,同时保持了计算成本的可控性,为多模态大型语言模型的微调提供了一种有效的解决方案。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

基本模型架构遵循LLaVA1.5的设计,其中使用了CLIP ViT-L作为视觉编码器,输入图像分辨率为336x336,补丁大小为14。适配器是一个两层的MLP,用于转换来自ViT的576个token。大型语言模型(LLM)是Vicuna-7B-v1.5。在所有实验的训练过程中,ViT和Vicuna的权重都被冻结。除非特别说明,否则应用于LLM的LoRA秩是32。

模型在两个阶段进行训练:预训练和指令微调。预训练阶段使用了ShareGPT4V数据集,包含由GPT4V生成的数据训练的标题器产生的130万个详细的字幕数据。指令微调阶段,采用了来自三个不同领域的多模态指令数据集:一般多任务、文档和生物医学。M3IT和ShareGPT4V Instruct是两个一般多任务指令数据集,而UReader收集的文档导向指令数据集包含来自多个公共数据集的图像和指令。还使用了PathVQA作为生物医学领域的指令数据。所有这些数据集都是公开的,并且按照UReader的数据划分进行训练和测试。表格1列出了预训练(PT)和监督指令微调(SFT)阶段的训练参数。

表 2 展示了在不同数据和MoE配置下训练的模型的实验结果。首先提供了官方LLaVA-1.5和LLaVA-Med模型在每个基准测试上的结果。然后,通过在不同数据集上单独训练plain-LoRA模型,并将其命名为LLaVA-1.5、LLaVA-Doc和LLaVA-Med。这些模型在与其训练数据集相对应的基准测试上的性能被视为该基准的基线性能。例如,特别重现的LLaVA-1.5†专门在一般多任务指令数据上训练,在Tiny LVLM-eHub上实现了与官方LLaVA-1.5 (307.2)相当的306.3的总分。通过混合不同数据集,发现LLaVA-Mix在eHub的整体性能比LLaVA1.5†降低了7-9分。这表明一般多任务数据与这些数据类型之间存在冲突,这种冲突可能会损害模型的一般多任务QA能力。提出的LLaVA-MoLE成功地解决了上述冲突。通过比较LLaVAMoLE[1,1,0]与LLaVA-Mix[1,1,0],可以观察到eHub的整体性能显著提高,与基线LLaVA-1.5†相当,而UReader基准测试的性能甚至超过了基线LLaVA-Doc†,例如在ChartQA上绝对性能提高了6.4。这可以证明混合专家已经学会了处理不同类型的指令数据并减少潜在的数据冲突。

表 3 展示了在不同LoRA秩下训练的模型的实验结果。可以看到,对于LoRA秩32、64和96,将文档指令数据与一般多任务指令数据混合都会导致eHub基准测试的性能下降。但通过比较实验LLaVA-Mix[1,1]-R32、LLaVA-Mix[1,1]-R64和LLaVA-Mix[1,1]-R96的结果,也发现增加LoRA秩,即增加模型容量,可以在一定程度上缓解数据冲突问题:eHub的总分从R32的298.8增加到R96的301.1。此外,如果将LoRA秩增加到128,似乎解决了这个问题。然而,作者认为简单地提高模型容量是一种昂贵的解决方案,会导致训练过程中的计算和内存增加。而提出的LLaVA-MoLE可以在不增加太多额外成本的情况下解决这个问题。值得注意的是,对于较小(32)和较大(128)的LoRA秩,LLaVAMoLE在两个基准测试上都显著优于LLaVA-Mix。

图 3 展示了在所有三个数据集的混合上训练的LLaVA-MoLE模型的路由选择的粗略分析。通过计算每个基准测试中分配给每个专家的token比例的均值和标准差,对第0层、第2层、第10层和第28层的结果进行了可视化。对于某些层,例如第2层和第10层,不同类型数据的专家选择模式相似,但在不同层之间有所不同。也有一些层(第10层和第28层),每种类型的数据都有自己的专家选择模式。没有观察到明显的模式表明某个特定专家在其他专家中一直更受青睐。但某些专家可能在特定数据集上比其他专家更倾向于被选择,例如,专家0在所有层的PathVQA样本中更频繁地被激活。

通过这些详细的实验设置和结果分析,证明了LLaVA-MoLE模型在解决多模态大型语言模型指令微调中的数据冲突问题方面是有效的,并且能够在保持计算成本可控的同时提高模型性能。

这篇关于LLaVA-MoLE:解决多模态大模型指令微调中的数据冲突问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1129454

相关文章

MybatisGenerator文件生成不出对应文件的问题

《MybatisGenerator文件生成不出对应文件的问题》本文介绍了使用MybatisGenerator生成文件时遇到的问题及解决方法,主要步骤包括检查目标表是否存在、是否能连接到数据库、配置生成... 目录MyBATisGenerator 文件生成不出对应文件先在项目结构里引入“targetProje

C#使用HttpClient进行Post请求出现超时问题的解决及优化

《C#使用HttpClient进行Post请求出现超时问题的解决及优化》最近我的控制台程序发现有时候总是出现请求超时等问题,通常好几分钟最多只有3-4个请求,在使用apipost发现并发10个5分钟也... 目录优化结论单例HttpClient连接池耗尽和并发并发异步最终优化后优化结论我直接上优化结论吧,

Java内存泄漏问题的排查、优化与最佳实践

《Java内存泄漏问题的排查、优化与最佳实践》在Java开发中,内存泄漏是一个常见且令人头疼的问题,内存泄漏指的是程序在运行过程中,已经不再使用的对象没有被及时释放,从而导致内存占用不断增加,最终... 目录引言1. 什么是内存泄漏?常见的内存泄漏情况2. 如何排查 Java 中的内存泄漏?2.1 使用 J

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

Python MySQL如何通过Binlog获取变更记录恢复数据

《PythonMySQL如何通过Binlog获取变更记录恢复数据》本文介绍了如何使用Python和pymysqlreplication库通过MySQL的二进制日志(Binlog)获取数据库的变更记录... 目录python mysql通过Binlog获取变更记录恢复数据1.安装pymysqlreplicat

Linux使用dd命令来复制和转换数据的操作方法

《Linux使用dd命令来复制和转换数据的操作方法》Linux中的dd命令是一个功能强大的数据复制和转换实用程序,它以较低级别运行,通常用于创建可启动的USB驱动器、克隆磁盘和生成随机数据等任务,本文... 目录简介功能和能力语法常用选项示例用法基础用法创建可启动www.chinasem.cn的 USB 驱动

numpy求解线性代数相关问题

《numpy求解线性代数相关问题》本文主要介绍了numpy求解线性代数相关问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 在numpy中有numpy.array类型和numpy.mat类型,前者是数组类型,后者是矩阵类型。数组

解决systemctl reload nginx重启Nginx服务报错:Job for nginx.service invalid问题

《解决systemctlreloadnginx重启Nginx服务报错:Jobfornginx.serviceinvalid问题》文章描述了通过`systemctlstatusnginx.se... 目录systemctl reload nginx重启Nginx服务报错:Job for nginx.javas

Oracle数据库使用 listagg去重删除重复数据的方法汇总

《Oracle数据库使用listagg去重删除重复数据的方法汇总》文章介绍了在Oracle数据库中使用LISTAGG和XMLAGG函数进行字符串聚合并去重的方法,包括去重聚合、使用XML解析和CLO... 目录案例表第一种:使用wm_concat() + distinct去重聚合第二种:使用listagg,

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1