推荐算法实战项目:AFM 原理以及案例实战(附完整 Python 代码)

2023-10-09 00:10

本文主要是介绍推荐算法实战项目:AFM 原理以及案例实战(附完整 Python 代码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文要介绍的是由浙江大学联合新加坡国立大学提出的AFM模型。通过名字也可以看出,此模型又是基于FM模型的改进,其中A代表”Attention“,即AFM模型实际上是在FM模型中引入了注意力机制改进得来的。

之所以要在FM模型中引入注意力机制,是因为传统的FM模型对所有的交叉特征都平等对待,即每个交叉特征的权重都是相同的(都为1)。而在实际应用中,不同交叉特征的重要程度往往是不一样的。

如果”一视同仁“地对待所有的交叉特征,不考虑不同特征对结果的影响程度,事实上消解了大量有价值的信息。

AFM 论文地址:这里

推荐系统中的注意力机制

这里再举个例子,说明一下注意力机制是如何在推荐系统中派上用场的。注意力机制基于假设——不同的交叉特征对结果的影响程度不同,以更直观的业务场景为例,用户对不同交叉特征的关注程度应该是不同的。

举例来说,如果应用场景是预测一位男性用户是否会购买一款键盘的可能性,那么**”性别=男”“购买历史包含鼠标“这一交叉特征,很可能比”性别=男”“年龄=30“**这一交叉特征重要,模型应该投入更多的”注意力“在前面的特征上。

正因如此,将注意力机制引入推荐系统中也显得理所当然了。

模型

在介绍AFM模型之前,先给出FM模型的方程:

FM模型方程

Pair-wise 交互层

Pair-wise 每个交叉向量都是通过对两个不同的向量进行内积来计算的。可以通过以下公式来描述:

Attention-based Pooling层

下面看一下作者是如何将注意力机制加入到FM模型中去的,具体如下:

作者提出了通过MLP来参数化注意力分数,作者称之为”注意力网络“,其定义如下:

AFM模型

下面给出完整的AFM框架图:

AFM框架

AFM模型的整体方程为:

完整源码&技术交流

技术要学会分享、交流,不建议闭门造车。一个人走的很快、一堆人可以走的更远。

文章中的完整源码、资料、数据、技术交流提升, 均可加知识星球交流群获取,群友已超过2000人,添加时切记的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:mlc2060,备注:来自 获取推荐资料
方式②、微信搜索公众号:机器学习社区,后台回复:推荐资料

代码实践

模型部分:

import torch
import torch.nn as nn
from BaseModel.basemodel import BaseModelclass AFM(BaseModel):def __init__(self, config, dense_features_cols, sparse_features_cols):super(AFM, self).__init__(config)self.num_fields = config['num_fields']self.embed_dim = config['embed_dim']self.l2_reg_w = config['l2_reg_w']# 稠密和稀疏特征的数量self.num_dense_feature = dense_features_cols.__len__()self.num_sparse_feature = sparse_features_cols.__len__()# AFM的线性部分,对应 ∑W_i*X_i, 这里包含了稠密和稀疏特征self.linear_model = nn.Linear(self.num_dense_feature + self.num_sparse_feature, 1)# AFM的Embedding层,只是针对稀疏特征,有待改进。self.embedding_layers = nn.ModuleList([nn.Embedding(num_embeddings=feat_dim, embedding_dim=config['embed_dim'])for feat_dim in sparse_features_cols])# Attention Networkself.attention = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True)self.projection = torch.nn.Linear(self.embed_dim, 1, bias=False)self.attention_dropout = nn.Dropout(config['dropout_rate'])# prediction layerself.predict_layer = torch.nn.Linear(self.embed_dim, 1)def forward(self, x):# 先区分出稀疏特征和稠密特征,这里是按照列来划分的,即所有的行都要进行筛选dense_input, sparse_inputs = x[:, :self.num_dense_feature], x[:, self.num_dense_feature:]sparse_inputs = sparse_inputs.long()# 求出线性部分linear_logit = self.linear_model(x)# 求出稀疏特征的embedding向量sparse_embeds = [self.embedding_layers[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]sparse_embeds = torch.cat(sparse_embeds, axis=-1)sparse_embeds = sparse_embeds.view(-1, self.num_sparse_feature, self.embed_dim)# calculate inner productrow, col = list(), list()for i in range(self.num_fields - 1):for j in range(i + 1, self.num_fields):row.append(i), col.append(j)p, q = sparse_embeds[:, row], sparse_embeds[:, col]inner_product = p * q# 通过Attention network得到注意力分数attention_scores = torch.relu(self.attention(inner_product))attention_scores = torch.softmax(self.projection(attention_scores), dim=1)# dim=1 按行求和attention_output = torch.sum(attention_scores * inner_product, dim=1)attention_output = self.attention_dropout(attention_output)# Prodict Layer# for regression problem with MSELossy_pred = self.predict_layer(attention_output) + linear_logit# for classifier problem with LogLoss# y_pred = torch.sigmoid(y_pred)return y_pred

在criteo数据集上测试,测试代码如下:

import torch
from AFM.network import AFM
from DeepCrossing.trainer import Trainer
import torch.utils.data as Data
from Utils.criteo_loader import getTestData, getTrainDataafm_config = \
{'num_fields': 26, # 这里配置的只是稀疏特征的个数'embed_dim': 8, # 用于控制稀疏特征经过Embedding层后的稠密特征大小'seed': 1024,'l2_reg_w': 0.001,'dropout_rate': 0.1,'num_epoch': 200,'batch_size': 64,'lr': 1e-3,'l2_regularization': 1e-4,'device_id': 0,'use_cuda': False,'train_file': '../Data/criteo/processed_data/train_set.csv','fea_file': '../Data/criteo/processed_data/fea_col.npy','validate_file': '../Data/criteo/processed_data/val_set.csv','test_file': '../Data/criteo/processed_data/test_set.csv','model_name': '../TrainedModels/AFM.model'
}if __name__ == "__main__":##################################################################################### AFM 模型####################################################################################training_data, training_label, dense_features_col, sparse_features_col = getTrainData(afm_config['train_file'], afm_config['fea_file'])train_dataset = Data.TensorDataset(torch.tensor(training_data).float(), torch.tensor(training_label).float())test_data = getTestData(afm_config['test_file'])test_dataset = Data.TensorDataset(torch.tensor(test_data).float())afm = AFM(afm_config, dense_features_cols=dense_features_col, sparse_features_cols=sparse_features_col)##################################################################################### 模型训练阶段##################################################################################### # 实例化模型训练器trainer = Trainer(model=afm, config=afm_config)# 训练trainer.train(train_dataset)# 保存模型trainer.save()##################################################################################### 模型测试阶段####################################################################################afm.eval()if afm_config['use_cuda']:afm.loadModel(map_location=lambda storage, loc: storage.cuda(afm_config['device_id']))afm = afm.cuda()else:afm.loadModel(map_location=torch.device('cpu'))y_pred_probs = afm(torch.tensor(test_data).float())y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))print("Test Data CTR Predict...\n ", y_pred.view(-1))

点击率预估结果如下(预测用户会点击输出为1,反之为0):

这篇关于推荐算法实战项目:AFM 原理以及案例实战(附完整 Python 代码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/168915

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

这15个Vue指令,让你的项目开发爽到爆

1. V-Hotkey 仓库地址: github.com/Dafrok/v-ho… Demo: 戳这里 https://dafrok.github.io/v-hotkey 安装: npm install --save v-hotkey 这个指令可以给组件绑定一个或多个快捷键。你想要通过按下 Escape 键后隐藏某个组件,按住 Control 和回车键再显示它吗?小菜一碟: <template

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

Hadoop企业开发案例调优场景

需求 (1)需求:从1G数据中,统计每个单词出现次数。服务器3台,每台配置4G内存,4核CPU,4线程。 (2)需求分析: 1G / 128m = 8个MapTask;1个ReduceTask;1个mrAppMaster 平均每个节点运行10个 / 3台 ≈ 3个任务(4    3    3) HDFS参数调优 (1)修改:hadoop-env.sh export HDFS_NAMENOD

如何用Docker运行Django项目

本章教程,介绍如何用Docker创建一个Django,并运行能够访问。 一、拉取镜像 这里我们使用python3.11版本的docker镜像 docker pull python:3.11 二、运行容器 这里我们将容器内部的8080端口,映射到宿主机的80端口上。 docker run -itd --name python311 -p

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于