Multi-gate Mixture-of-Experts(MMoE)

2023-10-06 23:50
文章标签 multi mixture gate experts mmoe

本文主要是介绍Multi-gate Mixture-of-Experts(MMoE),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 概述

在工业界经常会面对多个学习目标的场景,如在推荐系统中,除了要给用户推荐刚兴趣的物品之外,一些细化的指标,包括点击率,转化率,浏览时长等等,都会作为评判推荐系统效果好坏的重要指标,不同的是在不同的场景下对不同指标的要求不一样而已。在面对这种多任务的场景,最简单最直接的方法是针对每一个任务训练一个模型,显而易见,这种方式带来了巨大的成本开销,包括了计算成本和存储成本。多任务学习(Multi-task Learning)便由此而生,在多任务学习中,希望通过一个模型可以同时学习多个目标。然而在多任务学习中,多个任务之间通常存在着或是彼此联系或是巨大差异的现象,这就导致了多任务模型常常效果不佳。Google于2018年提出了Multi-gate Mixture-of-Experts(MMoE)模型[1]来对任务之间相互关系建模。

2. 算法原理

MMoE模型并不是凭空产出的,是在前人的工作上做了很多改进。多任务学习经过多年的发展,历史上也出现了很多多任务学习的模型。

2.1. Shared-Bottom模型

在多任务学习模型当中,最常见的一种模型就是shared-bottom模型,shared-bottom模型的结构如下图所示:

在这里插入图片描述

在shared-bottom模型中,每个任务都共享底部的网络,如上图中的Shared Bottom部分,然后在上层再根据任务的不同划分出多个tower network来分别学习不同的目标,如上图中的TowerA和Tower B。假设当前有 K K K个任务,输入特征通过shared-bottom网络后可以由函数 f f f表示,每一个tower网络的输出为函数 h k h^k hk,其中 k = 1 , 2 , ⋯ , K k=1,2,\cdots ,K k=1,2,,K,则shared-bottom模型可以表示为:

y k = h k ( f ( x ) ) y_k=h^k\left ( f\left ( x \right ) \right ) yk=hk(f(x))

2.2. Multi-gate Mixture-of-Experts(MMoE)模型

从MMoE的名称来看,可以看到主要包括两个部分,分别为:Multi-gate(多门控网络)和Mixture-of-Experts(混合专家)。

2.2.1. Mixture-of-Experts(MoE)模型

MoE模型可以表示为

y = ∑ i = 1 n g ( x ) i f i ( x ) y=\sum_{i=1}^{n}g\left ( x \right )_if_i\left ( x \right ) y=i=1ng(x)ifi(x)

其中 ∑ i = 1 n g ( x ) i = 1 \sum_{i=1}^{n}g\left ( x \right )_i=1 i=1ng(x)i=1 g ( x ) i g\left ( x \right )_i g(x)i表示的是 g ( x ) g\left ( x \right ) g(x)的第 i i i个输出值,代表的是选择专家 f i f_i fi的概率值。 f i ( x ) f_i\left ( x \right ) fi(x)是第 i i i个专家网络的值。MoE可以看作是基于多个独立模型的集成方法Ensemble,通过Ensemble的知识可知,通过Ensemble能够提高模型的性能。

也有将MoE作为一个独立的层[2],将多个MoE结构堆叠在另一个网络中,一个MoE层的输出作为下一层MoE层的输入,其输出作为另一个下一层的输入,其具体过程如下图所示:

在这里插入图片描述

2.2.2. One-gate Mixture-of-Experts(OMoE)模型

在shared-bottom模型中,无法实现对多个任务之间关系的建模,结合shared-bottom和MoE,便有了One-gate Mixture-of-Experts模型,其具体过程如下图所示:

在这里插入图片描述
假设当前有 K K K个任务,与Shared-Bottom模型一样,输入特征通过多个专家网络后可以由函数 f i ( x ) f_i\left ( x \right ) fi(x)表示,假设当前有 n n n个专家网络,即 i = 1 , 2 , ⋯ , n i=1,2,\cdots ,n i=1,2,,n,每一个任务对应的tower网络的输出为函数 h k h^k hk,其中 k = 1 , 2 , ⋯ , K k=1,2,\cdots ,K k=1,2,,K,则OMoE模型可以表示为:

y k = h k ( ∑ i = 1 n g ( x ) i f i ( x ) ) y_k=h^k\left ( \sum_{i=1}^{n}g\left ( x \right )_if_i\left ( x \right ) \right ) yk=hk(i=1ng(x)ifi(x))

2.2.3. Multi-gate Mixture-of-Experts(MMoE)模型

Multi-gate Mixture-of-Experts是One-gate Mixture-of-Experts的升级版本,借鉴门控网络的思想,将OMoE模型中的One-gate升级为Multi-gate,针对不同的任务有自己独立的门控网络,每个任务的gating networks通过最终输出权重不同实现对专家的选择。不同任务的门控网络可以学习到对专家的不同组合,因此模型能够考虑到了任务之间的相关性和区别。其具体过程如下图所示:

在这里插入图片描述

同样,假设当前有 K K K个任务,输入特征通过多个专家网络后可以由函数 f i ( x ) f_i\left ( x \right ) fi(x)表示,假设当前有 n n n个专家网络,即 i = 1 , 2 , ⋯ , n i=1,2,\cdots ,n i=1,2,,n,每一个任务对应的tower网络的输出为函数 h k h^k hk,其中 k = 1 , 2 , ⋯ , K k=1,2,\cdots ,K k=1,2,,K,则MMoE模型可以表示为:

y k = h k ( ∑ i = 1 n g k ( x ) i f i ( x ) ) y_k=h^k\left ( \sum_{i=1}^{n}g^k\left ( x \right )_if_i\left ( x \right ) \right ) yk=hk(i=1ngk(x)ifi(x))

其中, g k ( x ) i g^k\left ( x \right )_i gk(x)i为门控网络,可以由一个最简单的网络表示:

g k ( x ) = s o f t m a x ( W g k x ) g^k\left ( x \right )=softmax\left ( W_{gk}x \right ) gk(x)=softmax(Wgkx)

其中, W g k ∈ R n × d W_{gk}\in \mathbb{R}^{n\times d} WgkRn×d n n n表示专家的个数, d d d表示的是特征的维度。

3.总结

通过结合门控网络和混合专家组成的MMoE模型,从实验的结论上来看,能够利用同一个模型对多个任务同时建模,同时能够对多个任务之间的联系和区别建模。

参考文献

[1] Ma J, Zhao Z, Yi X, et al. Modeling task relationships in multi-task learning with multi-gate mixture-of-experts[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2018: 1930-1939.

[2] Shazeer, Noam, et al. “Outrageously large neural networks: The sparsely-gated mixture-of-experts layer.” arXiv preprint arXiv:1701.06538 (2017).

这篇关于Multi-gate Mixture-of-Experts(MMoE)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/154703

相关文章

多头注意力机制(Multi-Head Attention)

文章目录 多头注意力机制的作用多头注意力机制的工作原理为什么使用多头注意力机制?代码示例 多头注意力机制(Multi-Head Attention)是Transformer架构中的一个核心组件。它在机器翻译、自然语言处理(NLP)等领域取得了显著的成功。多头注意力机制的引入是为了增强模型的能力,使其能够从不同的角度关注输入序列的不同部分,从而捕捉更多层次的信息。 多头注意力机

【C++11 之新增容器 array、foward_list、tuple、unordered_(multi)map/set】应知应会

C++11 标准中新增了多个容器,这些容器为 C++ 程序员提供了更多的选择,以满足不同的编程需求。以下是对这些新容器的介绍和使用案例: std::array 介绍: std::array 是一个固定大小的数组容器,它在栈上分配内存,并提供了类似于标准库容器的接口。它提供了更好的类型安全性和范围检查,同时保持了与原生数组相似的性能。std::array 的大小必须在编译时确定,并且不能更改。

Elasticsearch java API (10)Multi Get API

Multi Get API编辑 多让API允许基于他们得到的文档列表 index, type和 id: MultiGetResponse multiGetItemResponses = client.prepareMultiGet().add("twitter", "tweet", "1") // <1> .add("twitter", "tweet

越复杂的CoT越有效吗?Complexity-Based Prompting for Multi-step Reasoning

Complexity-Based Prompting for Multi-step Reasoning 论文:https://openreview.net/pdf?id=yf1icZHC-l9 Github:https://github.com/FranxYao/chain-of-thought-hub 发表位置:ICLR 2023 Complexity-Based Prompting for

论文学习 Learning Robust Representations via Multi-View Information Bottleneck

Code available at https://github.com/mfederici/Multi-View-Information-Bottleneck 摘要:信息瓶颈原理为表示学习提供了一种信息论方法,通过训练编码器保留与预测标签相关的所有信息,同时最小化表示中其他多余信息的数量。然而,最初的公式需要标记数据来识别多余的信息。在这项工作中,我们将这种能力扩展到多视图无监督设置,其中提供

【论文阅读】MOA,《Mixture-of-Agents Enhances Large Language Model Capabilities》

前面大概了解了Together AI的新研究MoA,比较好奇具体的实现方法,所以再来看一下对应的文章论文。 论文:《Mixture-of-Agents Enhances Large Language Model Capabilities》 论文链接:https://arxiv.org/html/2406.04692v1 这篇文章的标题是《Mixture-of-Agents Enhances

Multi-Head RAG:多头注意力的激活层作为嵌入进行文档检索

现有的RAG解决方案可能因为最相关的文档的嵌入可能在嵌入空间中相距很远,这样会导致检索过程变得复杂并且无效。为了解决这个问题,论文引入了多头RAG (MRAG),这是一种利用Transformer的多头注意层的激活而不是解码器层作为获取多方面文档的新方案。 MRAG 不是利用最后一个前馈解码器层为最后一个令牌生成的单个激活向量,而是利用最后一个注意力层为最后一个令牌生成的H个单独的激活向量,然

BEV 中 multi-frame fusion 多侦融合(一)

文章目录 参数设置align_dynamic_thing:为了将动态物体的点云数据从上一帧对齐到当前帧流程 旋转函数平移公式filter_points_in_ego:筛选出属于特定实例的点get_intermediate_frame_info: 函数用于获取中间帧的信息,包括点云数据、传感器校准信息、自车姿态、边界框及其对应的实例标识等intermediate_keyframe_align

Transformer中的Self-Attention和Multi-Head Attention

2017 Google 在Computation and Language发表 当时主要针对于自然语言处理(之前的RNN模型记忆长度有限且无法并行化,只有计算完ti时刻后的数据才能计算ti+1时刻的数据,但Transformer都可以做到) 文章提出Self-Attention概念,在此基础上提出Multi-Head Atterntion 下面借鉴霹雳吧啦博主的视频进行学习: Se

大模型应用开发技术:Multi-Agent框架流程、源码及案例实战(一)

LlaMA 3 系列博客 基于 LlaMA 3 + LangGraph 在windows本地部署大模型 (一) 基于 LlaMA 3 + LangGraph 在windows本地部署大模型 (二) 基于 LlaMA 3 + LangGraph 在windows本地部署大模型 (三) 基于 LlaMA 3 + LangGraph 在windows本地部署大模型 (四) 基于 LlaMA 3