Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

2024-01-01 22:36

本文主要是介绍Deeplearning4j 实战 (22):基于DSSM的语义匹配建模,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

DSSM是微软在2013年提出的,最早用于搜索引擎语义召回的双塔模型。目前在工业界也广泛用于推荐召回、搜索相关性排序、语义召回等环节。DSSM是一个轻量级模型,在线上serving的时候,可以通过对query向量和doc向量计算内积,得到的相似值用来衡量query和doc的相似度,从而进行进一步的排序。下面就分别从DSSM模型结构、基于DL4J的DSSM建模、对开源数据集LCQMC的建模等几个环节来介绍如何使用DSSM模型。当然,由于DSSM模型的论文发表时间较早,发表时给出的模型结构比较简单,在我们具体实现的时候,会做一些调整,具体在介绍模型搭建的部分会提到。

1. DSSM模型简述

在论文中,query和doc分别通过各自独立的神经网络映射成一个语义向量。需要注意的是,原论文中doc是一个包含正样本和负样本的集合。正样本取1个,负样本取4个。论文中有提到,正样本是搜素后被点击的样本,负样本则是随机选取的搜索未被点击的样本集合。通过分别计算query的语义向量和正负doc样本的语义向量的余弦相似度,再通过softmax函数得到正负样本的概率分布后,和label计算交叉熵损失。这就是DSSM模型的大致的idea。下面先看下论文中对于DSSM描绘的架构图:
在这里插入图片描述
通过模型架构图可以看到,论文中是使用最简单的MLP对输入进行映射。这里需要提一下word hashing的操作。由于2013年时候word embedding技术还不是较广泛的使用,因此论文中的word hashing是在n-gram语言模型的基础上,通过hash操作将接近50W的词表计算每个词的索引值。这在当时是一种比较高效的做法,目前由于硬件的进步以及embedding技术的进一步成熟,可以直接使用预训练的embedding向量或者做端到端的建模。因此,在第三部分中构建DSSM模型的过程中,我们也是使用的端到端的方案。
在这里插入图片描述

上面这张截图中模型训练的有关描述。就像在本节开始时候提到的,通过softmax计算query和每个doc的余弦相似度的值归一化概率分布。由于softmax函数与cosine相似度的一致性,因此相似度越高的query-doc pair,其softmax值也会越接近于1。在损失函数部分,使用的是经典的log loss。这部分没啥说的。
另外需要说明的是,从ranking loss的角度,论文中的loss应当属于list-wise loss。当然,如果将负样本减少到一个或者doc集合中只有一个正样本或负样本(softmax更改为sigmoid函数),那就退化成pair-wise loss或者point-wise loss。为了方便起见,在第三部分的建模过程中,我们会使用point-wise loss。
对于搜索场景来说,双塔的输入分别是query和doc。对于推荐场景来说,双塔的输入可以是user和item或者item和item,用于U2I的召回或者I2I的召回。

2. LCQMC数据集

LCQMC是哈工大和阿里共同开源的用于QA的数据集,详情可参见论文。下载链接为:地址。压缩包中共有三个文件,三个文件都是以制表符作为分隔符。我们先来看下用于训练的部分数据的截图:
在这里插入图片描述
文件中有三列。最后一列用1或者0来代表 text_a 和 text_b两列文本的是否相关。如果把text_a列文本看作是query,那text_b列可以看作是doc。用于验证的文件中的内容也和训练文件中的数据格式相似,这里就不做另外截图了。

最后提一下,训练样本数量是:238767,验证的样本数量是:12501。

3. 基于DL4J的DSSM模型构建

在第一部分中,我们提到DSSM的论文中双塔内部是使用MLP结构。考虑到MLP结构的单一性,我们使用Embedding+LSTM+MLP的结构作为双塔的内部结构。虽然query和doc对应的塔结构相同,但是不做参数的共享。另外,由于LCQMC数据集中label是1或者0,因此我们将DSSM的输出层改为sigmoid + binary cross entropy loss。具体我们先给出代码片段:

private static ComputationGraph getDSSM(final int QUERY_VOCAB_SIZE, final int DOC_VOCAB_SIZE, final int VECTOR_SIZE) {ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(5 * 1e-3)).weightInit(WeightInit.XAVIER).seed(12345L).graphBuilder().addInputs("query", "doc").setInputTypes(InputType.recurrent(QUERY_VOCAB_SIZE), InputType.recurrent(DOC_VOCAB_SIZE)).addLayer("query-embedding", new EmbeddingSequenceLayer.Builder().nIn(QUERY_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "query").addLayer("query-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "query-embedding").addLayer("doc-embedding", new EmbeddingSequenceLayer.Builder().nIn(DOC_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "doc").addLayer("doc-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "doc-embedding").addVertex("query-embedding-lstm-last-output", new LastTimeStepVertex("query"), "query-embedding-lstm").addVertex("doc-embedding-lstm-last-output", new LastTimeStepVertex("doc"), "doc-embedding-lstm").addLayer("query-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "query-embedding-lstm-last-output").addLayer("doc-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "doc-embedding-lstm-last-output").addVertex("query-output-l2-norm", new L2NormalizeVertex(), "query-output").addVertex("doc-output-l2-norm", new L2NormalizeVertex(), "doc-output").addVertex("cosing-similar", new ElementWiseVertex(ElementWiseVertex.Op.Product), "query-output-l2-norm", "doc-output-l2-norm").addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT)	//bce.nIn(VECTOR_SIZE / 2).nOut(1).activation(Activation.SIGMOID).build(), "cosing-similar").setOutputs("out").build();ComputationGraph net = new ComputationGraph(conf);net.setListeners(new ScoreIterationListener(1));net.init();return net;
}

由于存在两个输入,因此使用DL4J中的ComputationGraph。这里需要说明的有几点:

  • LastTimeStepVertex的作用:获取LSTM最后一个time step输出的张量
  • L2NormalizeVertex的作用:L2归一化,将query和doc的向量转化为单位向量
  • ElementWiseVertex的作用:通过设置Op为点积,实际为计算query和doc单位向量的内积,因此L2NormalizeVertex + ElementWiseVertex联合起来的作用是计算向量间的余弦相似度值
  • 输出端使用sigmoid + bce 作point-wise的损失函数
    在这里插入图片描述

上面的截图中通过summary接口打印的模型结构和待训练参数。可见待训练参数68W。

另外,对于该静态方法,输入的几个参数QUERY_VOCAB_SIZE,DOC_VOCAB_SIZE,VECTOR_SIZE分别代表LCQMC数据集中text_a的词表大小和text_b的词表大小,以及词向量的大小。

需要指出的是,在第四部分进行建模的操作中,我们使用中文单字作为query和doc的最小粒度特征,而不做分词的处理。

4. DSSM模型训练和评估

首先介绍下数据处理的部分:

  • 读取训练文件,构建中文单字和单字的索引,存储在map结构中。同时记录最长的文本长度,用于后续的padding操作。
  • 再次读取文件,对每条记录构建MultiDataSet对象,并存储在LinkedList对象中。MultiDataSet对象中会存储query和doc作为输入,label作为输出,此外还有query和doc的mask张量,用于统一变长文本的处理。

我们看下具体的实现逻辑:

class DataSetInfo{public Map<String,Integer> queryDict = new TreeMap<>();public Map<String,Integer> docDict = new TreeMap<>();public int queryMaxLen = 0;public int docMaxLen = 0;
}private static DataSetInfo preprocess(String filePath) {DataSetInfo info = new DataSetInfo();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");if( null == splits || splits.length != 3 )continue;String query = splits[0];String doc = splits[1];if( query != null && query.length() > 0 ) {info.queryMaxLen = Math.max(query.length(), info.queryMaxLen);for( char c : query.toCharArray() ) {String charStr = String.valueOf(c);if( !info.queryDict.containsKey(charStr) ) {int curIndex = info.queryDict.size();info.queryDict.put(charStr, curIndex);}}}if( doc != null && doc.length() > 0 ) {info.docMaxLen = Math.max(doc.length(), info.docMaxLen);for( char c : doc.toCharArray() ) {String charStr = String.valueOf(c);if( !info.docDict.containsKey(charStr) ) {int curIndex = info.docDict.size();info.docDict.put(charStr, curIndex);}}}}}catch(Exception ex) {ex.printStackTrace();}finally {int curIndex = info.queryDict.size();info.queryDict.put("UNK", curIndex);//curIndex = info.docDict.size();info.docDict.put("UNK", curIndex);}return info;
}

这部分处理逻辑比较清晰,主要是先定义个DataSetInfo的类,里面包含了单字和单字索引的映射关系,还有最大文本长度。在finally部分,我们使用UNK代表所有未登录词。接着看下MultiDataSet的构造:

private static List<org.nd4j.linalg.dataset.api.MultiDataSet> getMultiDataIter(String filePath, DataSetInfo dataInfo) {List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new LinkedList<>();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");String query = splits[0];String doc = splits[1];String label = splits[2];if( query == null || query.isEmpty() ||doc == null || doc.isEmpty() || label == null)continue;//double[][] queryIndexArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexArray = new double[1][dataInfo.docMaxLen];double[][] queryIndexMaskArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexMaskArray = new double[1][dataInfo.docMaxLen];double[][] labelIndexArray = new double[1][1];//for( int i = 0; i < query.length(); ++i ) {queryIndexArray[0][i] = dataInfo.queryDict.getOrDefault(String.valueOf(query.charAt(i)),dataInfo.queryDict.get("UNK"));queryIndexMaskArray[0][i] = 1.0;}for( int i = 0; i < doc.length(); ++i ) {docIndexArray[0][i] = dataInfo.docDict.getOrDefault(String.valueOf(doc.charAt(i)),dataInfo.docDict.get("UNK"));docIndexMaskArray[0][i] = 1.0;}labelIndexArray[0][0] = Double.parseDouble(label);//org.nd4j.linalg.dataset.api.MultiDataSet mds = new MultiDataSet(new INDArray[] {Nd4j.create(queryIndexArray), Nd4j.create(docIndexArray)},new INDArray[] {Nd4j.create(labelIndexArray)},new INDArray[] {Nd4j.create(queryIndexMaskArray), Nd4j.create(docIndexMaskArray)},null);list.add(mds);}}catch(Exception ex) {ex.printStackTrace();}return list;
}

该部分逻辑主要是通过一个静态方法来读取训练文本中的每一行数据,并且针对text_a和text_b以及label转换成一个MultiDataSet对象,并存储在一个LinkedList对象中。需要注意的是Mask部分的处理。Mask张量中用1.0代表有效,0.0代表无效的部分。下面我们看下训练建模和评估的部分。

final int batchSize = 256;
final int embedding_size = 64;
DataSetInfo dataInfo = preprocess("data/lcqmc/train.tsv");
ComputationGraph dssm = getDSSM(dataInfo.queryDict.size(), dataInfo.docDict.size(), embedding_size);
System.out.println(dssm.summary());
List<org.nd4j.linalg.dataset.api.MultiDataSet> trainDataList = getMultiDataIter("data/lcqmc/train.tsv", dataInfo);
List<org.nd4j.linalg.dataset.api.MultiDataSet> testDataList = getMultiDataIter("data/lcqmc/test.tsv", dataInfo);
System.out.println("Finish Loading Train Data");
for(int epoch = 0; epoch < 5; ++epoch) {Collections.shuffle(trainDataList);MultiDataSetIterator trainIter = new IteratorMultiDataSetIterator(trainDataList.iterator(), batchSize);dssm.fit(trainIter);Evaluation eval = dssm.evaluate(new IteratorMultiDataSetIterator(testDataList.iterator(), batchSize));System.out.println(eval);
}

通过10个epoch的训练,我们最终在验证集上得到70%左右的准确率, loss值在0.4左右。
在这里插入图片描述

5. 总结

DSSM是一个经典的双塔模型,但其也有明显的缺点,就是两个塔之间是独立的,没有信息的交叉。这种信息的交叉对应推荐场景来说是很重要的。DSSM论文中的结构比较简单,是MLP为主,且输入层使用词袋模型进行处理,这其实忽略的上下文的语义信息,因此我们在实现的时候,使用LSTM模型来捕获序列的完整语义信息。当然,由于时间原因,我们这边并没有做分词处理,相信经过分词处理,在LCQMC数据集上的准确率可以进一步得到提升。另外,双塔的结构可以很灵活,内部可以直接上BERT来做,这里变体就太多,不做过多陈述了。

这篇关于Deeplearning4j 实战 (22):基于DSSM的语义匹配建模的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/560693

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

滚雪球学Java(87):Java事务处理:JDBC的ACID属性与实战技巧!真有两下子!

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE啦,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好习惯,别被干货淹没了哦~ 🏆本文收录于「滚雪球学Java」专栏,专业攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎大家关注&&收藏!持续更新中,up!up!up!! 环境说明:Windows 10

hdu 3065 AC自动机 匹配串编号以及出现次数

题意: 仍旧是天朝语题。 Input 第一行,一个整数N(1<=N<=1000),表示病毒特征码的个数。 接下来N行,每行表示一个病毒特征码,特征码字符串长度在1—50之间,并且只包含“英文大写字符”。任意两个病毒特征码,不会完全相同。 在这之后一行,表示“万恶之源”网站源码,源码字符串长度在2000000之内。字符串中字符都是ASCII码可见字符(不包括回车)。

二分最大匹配总结

HDU 2444  黑白染色 ,二分图判定 const int maxn = 208 ;vector<int> g[maxn] ;int n ;bool vis[maxn] ;int match[maxn] ;;int color[maxn] ;int setcolor(int u , int c){color[u] = c ;for(vector<int>::iter

基于UE5和ROS2的激光雷达+深度RGBD相机小车的仿真指南(五):Blender锥桶建模

前言 本系列教程旨在使用UE5配置一个具备激光雷达+深度摄像机的仿真小车,并使用通过跨平台的方式进行ROS2和UE5仿真的通讯,达到小车自主导航的目的。本教程默认有ROS2导航及其gazebo仿真相关方面基础,Nav2相关的学习教程可以参考本人的其他博客Nav2代价地图实现和原理–Nav2源码解读之CostMap2D(上)-CSDN博客往期教程: 第一期:基于UE5和ROS2的激光雷达+深度RG

POJ 3057 最大二分匹配+bfs + 二分

SampleInput35 5XXDXXX...XD...XX...DXXXXX5 12XXXXXXXXXXXXX..........DX.XXXXXXXXXXX..........XXXXXXXXXXXXX5 5XDXXXX.X.DXX.XXD.X.XXXXDXSampleOutput321impossible

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2