特征交叉系列:FM和深度神经网络的结合,DeepFM原理简述和实践

本文主要是介绍特征交叉系列:FM和深度神经网络的结合,DeepFM原理简述和实践,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

从FM,FFM到DeepFM

在上两节中介绍了FM和FFM

这两种算法是推荐算法中经典的特征交叉结构,FM将特征交叉分解到底层属性,通过底层属性的点乘来刻画特征交叉的计算,而FFM引入特征域的概念,对不同的特征对所引用的底层属性进行隔离,避免导致多重特征交叉下,底层属性表征产生互相拉扯,导致表达矛盾。

在深度学习时代之前,FM结构是主流的推荐算法,而随着深度学习的到来,FM逐渐和DNN深度神经网络进行结合,即期望构建模型既可以拥有FM的二阶特征交互的学习能力,也能够像DNN那样能够学习特征间高阶的复杂关系,其中DeepFM是最经典的FM和DNN结合的例子。

DeepFM提出于2017年,由于网络中只有FM和简单DNN,因此易于快速实现作为业务场景的一个Baseline。


DeepFM网络结构解析

DeepFM的网络结构如下

DeepFM模型结构

模型结构中左侧部分是FM,右侧部分是DNN,底层的输入从一个业务特征Field转化为稀疏onehot的Sparse Feature特征,模型的前馈传播有三块计算网络:

    1. Sparse Feature特征直接进入左侧FM的一阶线性层,完成一个wx+b的操作返回一阶的结果,只有Sparse Feature有值的位置才会进行权重加和
    1. Sparse Feature特征进入Dense Embedding层进行稠密向量映射,映射后进入左侧FM,映射的结果作为FM的隐向量进行点乘操作得到FM的二阶输出
  • 3.Sparse Feature特征进入Dense Embedding层进行稠密向量映射,映射后进入右侧DNN,所有Field的映射结果进行拼接作为DNN的输入,经过2层DNN隐藏层输出结果

最终DeepFM的结果是三个计算流程的相加组合,注意FM的二阶和DNN的底层输入是共享的,共用了Dense Embedding层的结果,因此隐藏层的学习不仅要考虑适配FM的交叉,也要适配DNN的高阶复杂关系学习。


DeepFM在PyTorch下的实践

本次实践的数据集和上一篇[特征交叉系列:完全理解FM因子分解机原理和代码实战]一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
PyTorch代码实现如下

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDatasetclass Linear(nn.Module):def __init__(self, feat_num):super(Linear, self).__init__()self.embedding = nn.Embedding(feat_num, 1)self.bias = nn.Parameter(torch.zeros(1))nn.init.xavier_normal_(self.embedding.weight.data)def forward(self, x):# [None, field_dim] => [None, field, 1] => [None, 1]out = self.embedding(x).sum(dim=1) + self.biasreturn outclass Embedding(nn.Module):def __init__(self, feat_num, k_dim):super(Embedding, self).__init__()self.embedding = nn.Embedding(feat_num, k_dim)nn.init.xavier_uniform_(self.embedding.weight.data)def forward(self, x):return self.embedding(x)class FM(nn.Module):def __init__(self):super(FM, self).__init__()def forward(self, x):square_of_sum = torch.sum(x, dim=1) ** 2sum_of_square = torch.sum(x ** 2, dim=1)ix = square_of_sum - sum_of_square# [None, 1]out = 0.5 * torch.sum(ix, dim=1, keepdim=True)return outclass DNN(nn.Module):def __init__(self, input_dim, fc_dims=(64, 16), dropout=0.1):super(DNN, self).__init__()layers = list()for fc_dim in fc_dims:layers.append(nn.Linear(input_dim, fc_dim))layers.append(nn.BatchNorm1d(fc_dim))layers.append(nn.ReLU())layers.append(nn.Dropout(p=dropout))input_dim = fc_dimlayers.append(nn.Linear(input_dim, 1))self.mlp = torch.nn.Sequential(*layers)def forward(self, x):return self.mlp(x)class Model(nn.Module):def __init__(self, field_num, feat_num, k_dim, fc_dims=(64, 16), dropout=0.1):super(Model, self).__init__()self.linear = Linear(feat_num=feat_num)self.embedding = Embedding(feat_num, k_dim)self.fm = FM()self.fc_input_dim = field_num * k_dimself.dnn = DNN(self.fc_input_dim, fc_dims, dropout)def forward(self, x):linear_out = self.linear(x)# [None, feat_size, k_dim]emb = self.embedding(x)fm_out = self.fm(emb)dnn_out = self.dnn(torch.reshape(emb, [-1, self.fc_input_dim]))out = torch.sigmoid(linear_out + fm_out + dnn_out)return out.squeeze(dim=1)

本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

其中x的长度10代表10个特征域,每个域的值是特征的全局位置索引,从0到71,一共72个特征。其中FM和DNN共用了Embedding对象。


DeepFM和FM,FFM模型效果对比

采用验证集的10次AUC不上升作为早停,FM,FFM,DeepFM的平均验证集AUC如下

FMFFMDeepFM
AUC0.6260.6300.631

DeepFM相比FM增加了DNN结构,AUC提升了0.5个百分点较为明显,而对比FFM,DeepFM也有略微提升,提升0.1个百分点。

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

这篇关于特征交叉系列:FM和深度神经网络的结合,DeepFM原理简述和实践的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1034293

相关文章

防止Linux rm命令误操作的多场景防护方案与实践

《防止Linuxrm命令误操作的多场景防护方案与实践》在Linux系统中,rm命令是删除文件和目录的高效工具,但一旦误操作,如执行rm-rf/或rm-rf/*,极易导致系统数据灾难,本文针对不同场景... 目录引言理解 rm 命令及误操作风险rm 命令基础常见误操作案例防护方案使用 rm编程 别名及安全删除

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法

PHP应用中处理限流和API节流的最佳实践

《PHP应用中处理限流和API节流的最佳实践》限流和API节流对于确保Web应用程序的可靠性、安全性和可扩展性至关重要,本文将详细介绍PHP应用中处理限流和API节流的最佳实践,下面就来和小编一起学习... 目录限流的重要性在 php 中实施限流的最佳实践使用集中式存储进行状态管理(如 Redis)采用滑动

Spring 中的切面与事务结合使用完整示例

《Spring中的切面与事务结合使用完整示例》本文给大家介绍Spring中的切面与事务结合使用完整示例,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录 一、前置知识:Spring AOP 与 事务的关系 事务本质上就是一个“切面”二、核心组件三、完

ShardingProxy读写分离之原理、配置与实践过程

《ShardingProxy读写分离之原理、配置与实践过程》ShardingProxy是ApacheShardingSphere的数据库中间件,通过三层架构实现读写分离,解决高并发场景下数据库性能瓶... 目录一、ShardingProxy技术定位与读写分离核心价值1.1 技术定位1.2 读写分离核心价值二

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

深入浅出Spring中的@Autowired自动注入的工作原理及实践应用

《深入浅出Spring中的@Autowired自动注入的工作原理及实践应用》在Spring框架的学习旅程中,@Autowired无疑是一个高频出现却又让初学者头疼的注解,它看似简单,却蕴含着Sprin... 目录深入浅出Spring中的@Autowired:自动注入的奥秘什么是依赖注入?@Autowired

MySQL分库分表的实践示例

《MySQL分库分表的实践示例》MySQL分库分表适用于数据量大或并发压力高的场景,核心技术包括水平/垂直分片和分库,需应对分布式事务、跨库查询等挑战,通过中间件和解决方案实现,最佳实践为合理策略、备... 目录一、分库分表的触发条件1.1 数据量阈值1.2 并发压力二、分库分表的核心技术模块2.1 水平分

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置