特征交叉系列:DeepCross(DCN-V2)理论和实践

2024-06-06 17:12

本文主要是介绍特征交叉系列:DeepCross(DCN-V2)理论和实践,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DCN之前FM系列特征交叉思路总结

在之前的推荐算法特征交叉系列中,已经介绍了从FM,FFM,到PNN,DeepFM,NFM,AFM这一系列围绕特征交叉展开的算法,这些算法都是以FM为基石,在FM的基础上优化其二阶交互表达能力的策略,主要有三种形式:

    1. 引入DNN优化FM的非线性能力:PNN,DeepFM,NFM都是将DNN和FM结合,其中PNN,NFM采用串联结构,DeepFM采用并联结构
    1. 改变隐向量之间的交互方式:以PNN为代表,它改变了FM单一的内积方式,尝试外积,内外融合来增强二阶交叉的表达
    1. 改变交差结果聚合的计算方式:在FM中一对隐向量逐位相乘再相加为一个标量,所有标量相加作为最终的二阶表征,而在NFM中只做逐位相乘但不相加,在IPNN中虽然也是隐向量逐位相乘再相加,但是所有标量没有直接相加而是平铺开来作为二阶表征,在AFM中更是采用注意力进行加权求和。

针对第一点,之前章节的代码实践都能证明引入DNN相比于FM有显著提升,DNN弥补了FM的非线性能力隐式高阶交叉能力。针对第二点外积和内外融合相比于内积有提升但是会带来更多参数的学习,详情见[特征交叉系列:PNN向量积模型理论和实践,FM和DNN的串联],而第三点不论怎么聚合,隐向量的逐位相乘都保留了下来,因此逐位相乘是FM特征交叉的精髓,是能够让两特征进行交叉表达的核心。

在DCN出来之前,以上算法都是没有改变FM这种范式底层结构,仅仅是对FM的某些细节做调整,而不论怎么调整也只能完成二阶交叉,对于高阶交叉只能寄希望于DNN在FM二阶的基础上能够隐式的学习到,因此对于高阶交叉这系列算法的表达能力是不充分的,引出本节介绍的算法DCN-V2,它彻底摆脱了FM的范式,通过一个递归的设计实现了任意有限阶的显示高阶交叉


DCN-V2模型结构介绍

DCN(Deep & Cross Network,深度交叉网络),有两个版本分别时DCN V1和DCN V2,DCN V1也叫DCN-V发表于2017年,DCN V2也叫DCN-M发表于2020年,两篇论文都来自Google公司,其中DCN-V1存在理论缺陷本节不做介绍,直接进入DCN-V2,两者的核心区别是前者使用向量Vector作为交叉层的学习权重,而后者采用矩阵Matrix,这也是命名DCN-V和DCN-M的由来。
DCN-V2的网络结构如下

DCN-V2

左边Stack和右边Parallel分别是DCN中交叉网络和DNN组合的策略,Stack代表串联,Parallel代表并联,相当于肯定了之前基于FM和DNN并联或者串联的算法,认为这两种形式都可以,具体用那种需要根据数据学习情况而定。
最底层是输入层,采用连续变量和稀疏变量的embedding拼接的方式,拼接的结果为一个向量,该向量作为一个整体进入Cross Network交叉层,交叉层每做一层交叉会输出一个新的向量,该新向量和输入的向量维度相同,与此同时每次交叉操作都需要引入一次原始输入,由于输入的向量维度和输入相同,因此这个操作可以无限递归下去,即Cross Network可以完成任意阶数的交叉操作。
Cross Network输出的结果作为高阶交叉的最终表征,如果是Stack策略则输入给DNN继续学习,如果是Parallel策略则和DNN的输出进行合并,最终映射到标签y上进行损失迭代学习。


DCN-V2的交叉网络为什么可以完成高阶交叉

上面简述了DCN-V2的整体结构,除了Cross Network其他和NFM,DeepFM没啥本质区别,现在开始深入Cross Network部分,它的计算公式如下

交叉计算公式

论文中的可视化表达地更加容易理解

交叉计算公式可视化

前面也提到DCN-V2的交叉部分是个可以无限递归的过程,i表示第i阶(本层),i+1表示第i+1阶(下一层),则下一层交叉的结果等于本层的结果经过一个M矩阵和b向量偏置,和原始输入的逐位相乘,再加上本层的结果,下面对于等式右边的元素,按照从左到右的位置分别讲解下它们在这个操作中的作用

    1. X0原始输入:它代表特征的一阶结果,在交叉网络中每一层都会引用它,每引用一次本层就会再这基础上再上一阶,X0是一个把DCN推上更高阶的增量因子
    1. W变换矩阵:该矩阵和本层的结果xi矩阵相乘,使得Xi每个元素都变成了Xi内所有元素的一个加权求和线性映射的结果,再配合上和X0的哈达马积逐位相乘就实现了类似FM的精髓交叉操作,而W的作用就是挑选出每个Xi位置下需要交叉的内容,将原始的Xi映射为最合适和X0做逐位相乘的形态
    1. b偏置向量:W对每个Xi位置上是个纯线性映射,因此加入偏置项b
    1. Xi本层输入:类似残差连接,是一种兜底策略,随着网络层数的增高使得模型的表达能力至少不弱于上一层,缓解梯度弥散和梯度爆炸的情况。

这个计算公式的核心是权重矩阵W,可以举个简单例子侧面证明W的有效性,比如交叉网络只做一层,这个计算公式能不能复现出FM?举例我们有三个特征,经过embedding之后分别有三个向量w1,w2,w3,在FM中会采用类似for循环的方式形成field_num*(field_num-1)对向量两两逐位相乘,分别是w1×w2,w2×w3,w1×w3,而在DCN-V2中输入被拼接在了一起w1||w2||w3,在一层中是X0和X0进行交叉,因此只要W能把第二个X0变换为w2||w3||w1,就可以复现FM的效果,如图

DCN第一层复现FM

因此只要设计一个矩阵W实现如图所示的把w1||w2||w3转化为w2||w3||w1即可,这个W当然很容易就可以设计出来,W的每一行都设计成一个onehot向量即可,只在红线的开始的那个位置和W矩阵相乘的对应位置处设置为1,其他位置全部为0即可。因此在DCN的第一层,W矩阵可以完全恢复出FM的效果,相当于在DCN-V2的第一层完成了二阶交叉,而往后的每一层都在之前基础上再做一次二阶交叉,因此DCN-V2可以通过这种递归的方式实现任意高阶的交叉。
本质上,DCN-V2采用了一个wx+b操作对每层的表征x做一个变换,使得它更好地和一阶输入x0进行哈达马积,而哈达马积就是在做特征交叉


再回过头来看DCN的高阶交叉思路

在明白交叉层的公式之后,我们回过头来猜测一下作者这么设计的脑回路,毕竟DCN彻底从FM这种field_num*(field_num-1)的for循环中跳出来了,非常具有创新性。
我们能不能把所有特征嵌入向量拼接在一起看成一个整体,让这个整体自己去做各种复杂的交叉操作,而不是采用FM这个手动两两的策略。对这个整体的输入,我们能不能要设计一个网络能对原始输入做任意高阶交叉,能不能设计一个变换函数,使得n+1阶交叉从n阶表达直接变换出来,最好能满足以下两个条件:

    1. 这个变换函数最好有通用固定的结构,每层可以复用,但是内部参数每层可以不同
    1. 每层交叉的结果和原始输入维度相同,因此这种变换可以一直持续延伸下去

如果这两个条件能够满足,就可以写一个for循环把原始输入无限变换再变换到任意高阶交叉,这是多么丝滑的体验。这种思路在深度学习中并不少见,图神经网络GNN也是设计了一个固定的信息传播聚合的公式,使得GNN可以让中心节点聚合任意跳的邻居节点信息作为自身的表征。
前文也提到DCN是将输入的向量拼接成成一个整体,而FM系是做半三角的循环拿到两两向量作为一对,紧接着引出DCN和FM的交叉性质的不同,即DCN是采用bit-wise交叉,而FM系列都是vector-wise的交叉,所谓bit-wise就是输入中不再区分不同特征field的概念,所有embeding的每个元素当成一个特征单元和其他任意元素单元进行交叉,即同一个field下会内部元素进行交叉,而vector-wise是一个field的向量和另一个field的向量进行交叉,field内部元素之间不进行交叉。


DCN-V2在PyTorch下的实践

本次实践的数据集和上一篇[特征交叉系列:完全理解FM因子分解机原理和代码实战]一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
DCN-V2的PyTorch代码实现如下

class Embedding(nn.Module):def __init__(self, feat_num, emb_num):super(Embedding, self).__init__()self.embedding = nn.Embedding(feat_num, emb_num)nn.init.xavier_normal_(self.embedding.weight.data)def forward(self, x):# [None, filed_num] => [None, filed_num, emb_num] => [None, filed_num * emb_num]return self.embedding(x).flatten(1)class DNN(nn.Module):def __init__(self, input_num, hidden_nums, dropout=0.1):super(DNN, self).__init__()layers = []input_num = input_numfor hidden_num in hidden_nums:layers.append(nn.Linear(input_num, hidden_num))layers.append(nn.BatchNorm1d(hidden_num))layers.append(nn.ReLU())layers.append(nn.Dropout(p=dropout))input_num = hidden_numself.mlp = nn.Sequential(*layers)for layer in self.mlp:if isinstance(layer, nn.Linear):nn.init.xavier_normal_(layer.weight.data)def forward(self, x):return self.mlp(x)class CrossCell(nn.Module):"""一个交叉单元"""def __init__(self, input_num):super(CrossCell, self).__init__()self.w = nn.Parameter(torch.randn(input_num, input_num))self.b = nn.Parameter(torch.randn(input_num, 1))nn.init.xavier_normal_(self.w.data)def forward(self, x0, xi):# [None, emb_num] => [None, emb_num, 1]xi = xi.unsqueeze(2)x0 = x0.unsqueeze(2)# [input_num, input_num] * [None, emb_num, 1] => [None, emb_num, 1] => [None, emb_num, 1] => [None, emb_num, 1]xii = (torch.matmul(self.w, xi) + self.b) * x0 + xireturn xii.squeeze(2)class CrossNet(nn.Module):def __init__(self, order_num, input_num):super(CrossNet, self).__init__()self.order = order_numself.cell_list = nn.ModuleList([CrossCell(input_num) for i in range(order_num)])def forward(self, x0):xi = x0for i in range(self.order):xi = self.cell_list[i](x0=x0, xi=xi)return xiclass DCN(nn.Module):def __init__(self, field_num, feat_dim, emb_num, order_num, dropout=0.1, method='parallel',hidden_nums=(128, 64, 32)):super(DCN, self).__init__()input_num = field_num * emb_numself.embedding = Embedding(feat_num=feat_dim, emb_num=emb_num)self.dnn = DNN(input_num=input_num, hidden_nums=hidden_nums, dropout=dropout)self.cross_net = CrossNet(order_num=order_num, input_num=input_num)if method not in ('parallel', 'stacked'):raise ValueError('unknown combine type: ' + method)self.method = methodlinear_dim = hidden_nums[-1]if self.method == 'parallel':linear_dim = linear_dim + input_numself.linear = nn.Linear(linear_dim, 1)nn.init.xavier_normal_(self.linear.weight.data)def forward(self, x):emb = self.embedding(x)  # [None, field * emb_num]cross_out = self.cross_net(emb)  # [None, input_num]if self.method == 'parallel':dnn_out = self.dnn(emb)  # [None, input_num]out = torch.concat([cross_out, dnn_out], dim=1)else:out = self.dnn(cross_out)  # [None, input_num]out = self.linear(out)return torch.sigmoid(out).squeeze(dim=1)

其中CrossCell子模块是一层交叉层,在CrossNet中指定阶数创建多个CrossCell的列表。在DCN主模块中设定了stacked,parallel两种策略。
本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

其中x的长度10代表10个特征域,每个域的值是特征的全局位置索引,从0到71,一共72个特征。


DCN-V2调参效果对比

对阶数(order_num)和融合策略(method)这两个参数进行调参,分别尝试1~4层交叉层,stacked和parallel两种策略,采用10次验证集AUC不上升作为早停条件,验证集的平均AUC如下

DCN调参AUC并行parallel串行stacked
1层交叉(2阶)0.63310.6325
2层交叉(3阶)0.63340.6320
3层交叉(4阶)0.63480.6344
4层交叉(5阶)0.63300.6326

结论是不论parallel还是stacked,在当前数据集下DCN使用3层交叉达到最优AUC,其中parallel略高一点,再往上对叠交叉层AUC会下降。
再对比一下之前文章中实践的FM,FFM,PNN等一系列算法,验证集AUC和参数规模如下

算法AUC参数量
FM0.6274361
FFM0.63172953
IPNN0.632215553
OPNN0.632627073
PNN*0.634229953
DeepFM0.632212746
NFM0.632910186
DCN-parallel-30.6348110017
DCN-stacked-30.6344109857

image.png

DCN都取得了目前为止的SOTA结果,其次是PNN*,发现随着模型参数量的增大,模型的预测效果也随着提升,DCN的参数规模原高于其他FM系的算法,在本例中即使只用1层交叉DCN的参数规模也能达到5万(是FM参数量的100倍以上)。从结果看,引入高阶交叉的DCN-V2确实比只有二阶交叉的FM系高出一筹,但是模型复杂度也大幅提升。

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

这篇关于特征交叉系列:DeepCross(DCN-V2)理论和实践的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1036719

相关文章

在C#中获取端口号与系统信息的高效实践

《在C#中获取端口号与系统信息的高效实践》在现代软件开发中,尤其是系统管理、运维、监控和性能优化等场景中,了解计算机硬件和网络的状态至关重要,C#作为一种广泛应用的编程语言,提供了丰富的API来帮助开... 目录引言1. 获取端口号信息1.1 获取活动的 TCP 和 UDP 连接说明:应用场景:2. 获取硬

Java内存泄漏问题的排查、优化与最佳实践

《Java内存泄漏问题的排查、优化与最佳实践》在Java开发中,内存泄漏是一个常见且令人头疼的问题,内存泄漏指的是程序在运行过程中,已经不再使用的对象没有被及时释放,从而导致内存占用不断增加,最终... 目录引言1. 什么是内存泄漏?常见的内存泄漏情况2. 如何排查 Java 中的内存泄漏?2.1 使用 J

Linux中Curl参数详解实践应用

《Linux中Curl参数详解实践应用》在现代网络开发和运维工作中,curl命令是一个不可或缺的工具,它是一个利用URL语法在命令行下工作的文件传输工具,支持多种协议,如HTTP、HTTPS、FTP等... 目录引言一、基础请求参数1. -X 或 --request2. -d 或 --data3. -H 或

Docker集成CI/CD的项目实践

《Docker集成CI/CD的项目实践》本文主要介绍了Docker集成CI/CD的项目实践,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、引言1.1 什么是 CI/CD?1.2 docker 在 CI/CD 中的作用二、Docke

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

2024年流动式起重机司机证模拟考试题库及流动式起重机司机理论考试试题

题库来源:安全生产模拟考试一点通公众号小程序 2024年流动式起重机司机证模拟考试题库及流动式起重机司机理论考试试题是由安全生产模拟考试一点通提供,流动式起重机司机证模拟考试题库是根据流动式起重机司机最新版教材,流动式起重机司机大纲整理而成(含2024年流动式起重机司机证模拟考试题库及流动式起重机司机理论考试试题参考答案和部分工种参考解析),掌握本资料和学校方法,考试容易。流动式起重机司机考试技

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识