字节跳动端到端深度学习召回算法

2024-04-13 19:32

本文主要是介绍字节跳动端到端深度学习召回算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

706b9acb8e32dedfbadca3edfb76b29f.png

来源:DataFunTalk
本文约2600字,建议阅读5分钟
本文为你介绍字节跳动AML Team在大规模推荐中构建的可学习的索引结构。

[ 导读 ] 传统的召回算法一般基于双塔结构并加以approximately nearest neighbor search (ANN) 或者maximum inner productive search (MIPS),比如fast ball tree (FBT),hierarchical navigable small world (HNSW) 等。这些传统的算法embedding的训练目标和ANN的目标不一致,导致ANN的损失无法学习。目前比较著名的解决思路是构建一个tree-based model如TDM等。

我们今天将介绍字节跳动AML Team在大规模推荐中构建的可学习的索引结构,使得embedding的训练目标和索引结构的训练目标可以一致学习,达到良好的召回效果,它不仅局限于广告业务,在推荐和搜索业务中也有应用。

本文将从以下几方面展开:

  • Deep retrieval的核心模型

  • 如何训练structure model

  • 思考与讨论

  • 精选问答

01、Deep retrieval的核心

20051f7534662b22adef7b36166b88fd.png

如图所示我们可以根据DR的structure的KxD的矩阵构造出path。我们可以把这种path看成层级的聚类,每个path里面有很多的item,每个item也可以属于多个path,这样我们可以保留item的多元化信息。比如“同仁堂”可能是中药企业,也可以是一个相声,所以我们在搜索“同仁堂”对应的文章时,它既有可能在中医药的path出现,又有可能在相声中出现,达到了我们multi-path的效果。

1. 训练阶段的structure loss

0b6d82ea87aa9cedbe92154bf6b12aed.png

从上图的图例我们可以看到网络的结构,在第一层得到用户的embedding x 对应c_1的概率,之后path中的每一段都将用户embedding与之前的path embedding串联,最终得到path中当前code的条件概率。根据联合概率公式最终的概率为:

p(c|x)=p(c_1,c_2,c_3|x)=p(c_1|x)p(c_2|c_1,x)p(c_3|c_2,c_1,x)

在训练中已知正例用户embedding x和item id y, 如果我们知道y所在的path为π(y),则:

maxlogp(π(y)|x)=logp(π(y)1|x)+logp(π(y)2|π(y)1,x)+logp(π(y)3|π(y)1, π(y)2,x)

2. serving阶段的beam search

在serving阶段我们采用的是beam search的算法,具体如图所示:

4dd55cc78ef7ea7d88dcef5e8a1cd886.png

通过图中示意的方法,我们在每一层选概率最大的B个node向下传递,B是指beam size,通常选10个左右。最后我们选出B个path并merge其中的item。

02、如何训练structure model

1. EM算法

在DR中我们需要同时训练structure model的参数(记为θ),以及所有item到path的mapping(记为π),则训练目标为:

115039c9dbfd8b7947389e0c88a537e4.png

其中J是指J条path。我们想交替训练π和θ,于是采用EM算法来共同训练参数和mapping。最开始我们随机初始化θ和π并轮流进行E-step和M-step。

在E-step中,我们进行以下操作:一是可用任何基于梯度的优化算法优化θ,因为p_θ是可微的。二是对于每一个path c和item v计算它们的likelihood,记为s[v,c],也称为hidden score。

a362cfc6633dca386efa7f973dba19e8.png

其中假设对于任何一个item v出现了n次,对应n个用户xi,我们计算平均p(c|x, θ)的likelihood。由于可能的path有K^D个而我们不可能全部计算,所以我们将只选取beam search分数较高的path并记录其hidden score。

在M-step中,我们需要从hidden path和hidden score中更新π(v)。最直接的方法是对于每一个item我们选取hidden path中分数最高的path作为新的π,这样在EM算法的objective function显然会达到很高的分数,但是这样做有一个缺点,即导致很多item学到一个path,使得path过于集中,即有的path中有大量的item,有的path里面没有item,这样的结果送到下游任务时,下游的压力就无法控制。为此我们引入patch-size penalty,我们令f(x)=x^4,α是可调参数,有:

a383d89e16e34ee668e0a204d8b5c609.png

即减去所有path,每个path里面的item个数的四次方,如果当前path的item已经很多时,可以有效抑制item继续增加。

2. 在线EM算法

对于流式训练,我们设计了在线EM算法。在E-step中,我们将使用一个滑动的平均hidden score并且动态跟踪一个固定大小的hidden path set。在M-step中我们采取定时任务的方式,从Parameter Server里面读取每个item的hidden path和hidden score,然后运行上段所说的penalty 算法计算出新的true path并写入到PS。

3. 多任务学习

8f869d08d5a5eb9f33a9d638871e630e.png

现在的DR采用multi-task learning的机制,我们使用structure loss来训练structure model以及item-path mapping,同时我们也保留了点乘模型比如FFM,NN模型来训练user或者item embedding用作reranker。在serving过程中,我们通过beam search找出hidden path以及他们的item,先经过reranker经过初步筛选出固定条数的candidates,再精排。这样可以减缓了粗排和精排的压力,另一方面也可以控制出口条数。

03、思考与讨论

与传统的ANN相比,DR的聚类更注重用户侧行为而不是item本身,比如足球视频和汽车视频在ANN召回中可能不在同一类但是在DR中会在同一类。DR其实不是利用item embeddeing本身,而是利用user和item之间的信息,利用hidden score进行聚类,即虽然两个item本身并不相近,但是他们可能会被同一种user消费。所以DR中path里面item的diversity会比ANN高很多。

因此DR更偏向于偏重用户行为的应用场景,比如广告或者推荐,在搜索中DR通常会降低相关性。

DR的structure model目前只用了正例没有负例,负例只在rerank model中使用。而且structure model只使用了item ID embedding,没有使用item侧特征,item侧特征只在rerank model中使用。除此之外,DR的学习目标之来源于user,item pair没有体现相关性,如果能将相关性loss引入DR loss来端到端学习用户行为和相关性也许可以解决搜索遇到的问题。

04、精选问答

Q:同一个item是否属于同一个D?

A:首先,同一个item是可以属于多个path, 比如“同仁堂”既可以属于相声的path又可以属于中医药的path, 在不同层的code中,item也可以属于多个,比如属于1,2,3 code,1,2,4 code。着代表两个path在类方向是一致的。

Q:retrieval算法学到的聚类结构与U2U的算法的聚类结构有什么关系?

A:有可能有一定关系,聚类的结果更容易把相同用户消费的物品聚到一起。

Q:什么在检索的过程中要用beam search而不是全部检索完?

A:因为一般线上K是100到1000,D是3,如果全部检索则需要检索至少百万级别的path,是不符合实际的。所以我们需要一个方法选择比如top20的path,这个方法选择的top20和实际的top20非常相近,beam search这个方法满足了我们的需求。

Q:EM算法的收敛性是否有保证,在实际应用中是否会出现不收敛的情况?

A:在理论上是有一些paper论证过EM算法在哪些条件下可以收敛,这些理论上的假设理论性比较强。在实际情况下我们回去检验一些这些条件,有时候也可以直接看结果。在非流式训练的情况下,我们的算法经过3到5个M-step以后会收敛到稳定的值。在流式训练中,我们通过定时M-Step的方法,实际在5到10次M-step可以达到收敛,这些都是一些比较实践的方法。

Q:这个方法是召回用户的多兴趣那么也没有和其他的一些用transformer的方法做对比?

A:我们在做多兴趣召回,目前和transformer的关系还不是特别紧密,我们大概的做法也是生成不同的user embedding, 对不同的user embedding做 beam search, 对 search的结果进行一些merge并通过一些loss控制不同的embedding学习不同的兴趣,但是transformer我们还没有尝试。

Q:不加multi-task模型会变成什么样?M-step是否只做一次?

A:不加multi-task模型学到的path非常不平衡,虽然我们有penalty进行控制但是还是不够,这是因为不加multi-task负例没有被利用,item embedding也没有被充分利用,所以这就是为什么我们现阶段使用了multi-task。M-step需要做不止一次的,实际我们需要等到模型收敛,这至少要一两天的时间。

Q:这个模型的线上效果如何?

A:这个模型已经在字节跳动不少的产品上线,覆盖广告和推荐,海内外产品都有应用,效果还是很成功的。

今天的分享就到这里,谢谢大家。

编辑:于腾凯

校对:林亦霖f115704334fd80d8cbac732f7d6def7f.png

这篇关于字节跳动端到端深度学习召回算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/901019

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig