Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

2024-01-01 22:36

本文主要是介绍Deeplearning4j 实战 (22):基于DSSM的语义匹配建模,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

DSSM是微软在2013年提出的,最早用于搜索引擎语义召回的双塔模型。目前在工业界也广泛用于推荐召回、搜索相关性排序、语义召回等环节。DSSM是一个轻量级模型,在线上serving的时候,可以通过对query向量和doc向量计算内积,得到的相似值用来衡量query和doc的相似度,从而进行进一步的排序。下面就分别从DSSM模型结构、基于DL4J的DSSM建模、对开源数据集LCQMC的建模等几个环节来介绍如何使用DSSM模型。当然,由于DSSM模型的论文发表时间较早,发表时给出的模型结构比较简单,在我们具体实现的时候,会做一些调整,具体在介绍模型搭建的部分会提到。

1. DSSM模型简述

在论文中,query和doc分别通过各自独立的神经网络映射成一个语义向量。需要注意的是,原论文中doc是一个包含正样本和负样本的集合。正样本取1个,负样本取4个。论文中有提到,正样本是搜素后被点击的样本,负样本则是随机选取的搜索未被点击的样本集合。通过分别计算query的语义向量和正负doc样本的语义向量的余弦相似度,再通过softmax函数得到正负样本的概率分布后,和label计算交叉熵损失。这就是DSSM模型的大致的idea。下面先看下论文中对于DSSM描绘的架构图:
在这里插入图片描述
通过模型架构图可以看到,论文中是使用最简单的MLP对输入进行映射。这里需要提一下word hashing的操作。由于2013年时候word embedding技术还不是较广泛的使用,因此论文中的word hashing是在n-gram语言模型的基础上,通过hash操作将接近50W的词表计算每个词的索引值。这在当时是一种比较高效的做法,目前由于硬件的进步以及embedding技术的进一步成熟,可以直接使用预训练的embedding向量或者做端到端的建模。因此,在第三部分中构建DSSM模型的过程中,我们也是使用的端到端的方案。
在这里插入图片描述

上面这张截图中模型训练的有关描述。就像在本节开始时候提到的,通过softmax计算query和每个doc的余弦相似度的值归一化概率分布。由于softmax函数与cosine相似度的一致性,因此相似度越高的query-doc pair,其softmax值也会越接近于1。在损失函数部分,使用的是经典的log loss。这部分没啥说的。
另外需要说明的是,从ranking loss的角度,论文中的loss应当属于list-wise loss。当然,如果将负样本减少到一个或者doc集合中只有一个正样本或负样本(softmax更改为sigmoid函数),那就退化成pair-wise loss或者point-wise loss。为了方便起见,在第三部分的建模过程中,我们会使用point-wise loss。
对于搜索场景来说,双塔的输入分别是query和doc。对于推荐场景来说,双塔的输入可以是user和item或者item和item,用于U2I的召回或者I2I的召回。

2. LCQMC数据集

LCQMC是哈工大和阿里共同开源的用于QA的数据集,详情可参见论文。下载链接为:地址。压缩包中共有三个文件,三个文件都是以制表符作为分隔符。我们先来看下用于训练的部分数据的截图:
在这里插入图片描述
文件中有三列。最后一列用1或者0来代表 text_a 和 text_b两列文本的是否相关。如果把text_a列文本看作是query,那text_b列可以看作是doc。用于验证的文件中的内容也和训练文件中的数据格式相似,这里就不做另外截图了。

最后提一下,训练样本数量是:238767,验证的样本数量是:12501。

3. 基于DL4J的DSSM模型构建

在第一部分中,我们提到DSSM的论文中双塔内部是使用MLP结构。考虑到MLP结构的单一性,我们使用Embedding+LSTM+MLP的结构作为双塔的内部结构。虽然query和doc对应的塔结构相同,但是不做参数的共享。另外,由于LCQMC数据集中label是1或者0,因此我们将DSSM的输出层改为sigmoid + binary cross entropy loss。具体我们先给出代码片段:

private static ComputationGraph getDSSM(final int QUERY_VOCAB_SIZE, final int DOC_VOCAB_SIZE, final int VECTOR_SIZE) {ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(5 * 1e-3)).weightInit(WeightInit.XAVIER).seed(12345L).graphBuilder().addInputs("query", "doc").setInputTypes(InputType.recurrent(QUERY_VOCAB_SIZE), InputType.recurrent(DOC_VOCAB_SIZE)).addLayer("query-embedding", new EmbeddingSequenceLayer.Builder().nIn(QUERY_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "query").addLayer("query-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "query-embedding").addLayer("doc-embedding", new EmbeddingSequenceLayer.Builder().nIn(DOC_VOCAB_SIZE + 1).nOut(VECTOR_SIZE).build(), "doc").addLayer("doc-embedding-lstm", new LSTM.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "doc-embedding").addVertex("query-embedding-lstm-last-output", new LastTimeStepVertex("query"), "query-embedding-lstm").addVertex("doc-embedding-lstm-last-output", new LastTimeStepVertex("doc"), "doc-embedding-lstm").addLayer("query-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "query-embedding-lstm-last-output").addLayer("doc-output", new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(VECTOR_SIZE / 2).activation(Activation.LEAKYRELU).build(), "doc-embedding-lstm-last-output").addVertex("query-output-l2-norm", new L2NormalizeVertex(), "query-output").addVertex("doc-output-l2-norm", new L2NormalizeVertex(), "doc-output").addVertex("cosing-similar", new ElementWiseVertex(ElementWiseVertex.Op.Product), "query-output-l2-norm", "doc-output-l2-norm").addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT)	//bce.nIn(VECTOR_SIZE / 2).nOut(1).activation(Activation.SIGMOID).build(), "cosing-similar").setOutputs("out").build();ComputationGraph net = new ComputationGraph(conf);net.setListeners(new ScoreIterationListener(1));net.init();return net;
}

由于存在两个输入,因此使用DL4J中的ComputationGraph。这里需要说明的有几点:

  • LastTimeStepVertex的作用:获取LSTM最后一个time step输出的张量
  • L2NormalizeVertex的作用:L2归一化,将query和doc的向量转化为单位向量
  • ElementWiseVertex的作用:通过设置Op为点积,实际为计算query和doc单位向量的内积,因此L2NormalizeVertex + ElementWiseVertex联合起来的作用是计算向量间的余弦相似度值
  • 输出端使用sigmoid + bce 作point-wise的损失函数
    在这里插入图片描述

上面的截图中通过summary接口打印的模型结构和待训练参数。可见待训练参数68W。

另外,对于该静态方法,输入的几个参数QUERY_VOCAB_SIZE,DOC_VOCAB_SIZE,VECTOR_SIZE分别代表LCQMC数据集中text_a的词表大小和text_b的词表大小,以及词向量的大小。

需要指出的是,在第四部分进行建模的操作中,我们使用中文单字作为query和doc的最小粒度特征,而不做分词的处理。

4. DSSM模型训练和评估

首先介绍下数据处理的部分:

  • 读取训练文件,构建中文单字和单字的索引,存储在map结构中。同时记录最长的文本长度,用于后续的padding操作。
  • 再次读取文件,对每条记录构建MultiDataSet对象,并存储在LinkedList对象中。MultiDataSet对象中会存储query和doc作为输入,label作为输出,此外还有query和doc的mask张量,用于统一变长文本的处理。

我们看下具体的实现逻辑:

class DataSetInfo{public Map<String,Integer> queryDict = new TreeMap<>();public Map<String,Integer> docDict = new TreeMap<>();public int queryMaxLen = 0;public int docMaxLen = 0;
}private static DataSetInfo preprocess(String filePath) {DataSetInfo info = new DataSetInfo();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");if( null == splits || splits.length != 3 )continue;String query = splits[0];String doc = splits[1];if( query != null && query.length() > 0 ) {info.queryMaxLen = Math.max(query.length(), info.queryMaxLen);for( char c : query.toCharArray() ) {String charStr = String.valueOf(c);if( !info.queryDict.containsKey(charStr) ) {int curIndex = info.queryDict.size();info.queryDict.put(charStr, curIndex);}}}if( doc != null && doc.length() > 0 ) {info.docMaxLen = Math.max(doc.length(), info.docMaxLen);for( char c : doc.toCharArray() ) {String charStr = String.valueOf(c);if( !info.docDict.containsKey(charStr) ) {int curIndex = info.docDict.size();info.docDict.put(charStr, curIndex);}}}}}catch(Exception ex) {ex.printStackTrace();}finally {int curIndex = info.queryDict.size();info.queryDict.put("UNK", curIndex);//curIndex = info.docDict.size();info.docDict.put("UNK", curIndex);}return info;
}

这部分处理逻辑比较清晰,主要是先定义个DataSetInfo的类,里面包含了单字和单字索引的映射关系,还有最大文本长度。在finally部分,我们使用UNK代表所有未登录词。接着看下MultiDataSet的构造:

private static List<org.nd4j.linalg.dataset.api.MultiDataSet> getMultiDataIter(String filePath, DataSetInfo dataInfo) {List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new LinkedList<>();try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){String line = null;int lineIndex = 0;while( (line = br.readLine()) != null ) {if( lineIndex == 0 ) {lineIndex++;continue;}String[] splits = line.split("\t");String query = splits[0];String doc = splits[1];String label = splits[2];if( query == null || query.isEmpty() ||doc == null || doc.isEmpty() || label == null)continue;//double[][] queryIndexArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexArray = new double[1][dataInfo.docMaxLen];double[][] queryIndexMaskArray = new double[1][dataInfo.queryMaxLen];double[][] docIndexMaskArray = new double[1][dataInfo.docMaxLen];double[][] labelIndexArray = new double[1][1];//for( int i = 0; i < query.length(); ++i ) {queryIndexArray[0][i] = dataInfo.queryDict.getOrDefault(String.valueOf(query.charAt(i)),dataInfo.queryDict.get("UNK"));queryIndexMaskArray[0][i] = 1.0;}for( int i = 0; i < doc.length(); ++i ) {docIndexArray[0][i] = dataInfo.docDict.getOrDefault(String.valueOf(doc.charAt(i)),dataInfo.docDict.get("UNK"));docIndexMaskArray[0][i] = 1.0;}labelIndexArray[0][0] = Double.parseDouble(label);//org.nd4j.linalg.dataset.api.MultiDataSet mds = new MultiDataSet(new INDArray[] {Nd4j.create(queryIndexArray), Nd4j.create(docIndexArray)},new INDArray[] {Nd4j.create(labelIndexArray)},new INDArray[] {Nd4j.create(queryIndexMaskArray), Nd4j.create(docIndexMaskArray)},null);list.add(mds);}}catch(Exception ex) {ex.printStackTrace();}return list;
}

该部分逻辑主要是通过一个静态方法来读取训练文本中的每一行数据,并且针对text_a和text_b以及label转换成一个MultiDataSet对象,并存储在一个LinkedList对象中。需要注意的是Mask部分的处理。Mask张量中用1.0代表有效,0.0代表无效的部分。下面我们看下训练建模和评估的部分。

final int batchSize = 256;
final int embedding_size = 64;
DataSetInfo dataInfo = preprocess("data/lcqmc/train.tsv");
ComputationGraph dssm = getDSSM(dataInfo.queryDict.size(), dataInfo.docDict.size(), embedding_size);
System.out.println(dssm.summary());
List<org.nd4j.linalg.dataset.api.MultiDataSet> trainDataList = getMultiDataIter("data/lcqmc/train.tsv", dataInfo);
List<org.nd4j.linalg.dataset.api.MultiDataSet> testDataList = getMultiDataIter("data/lcqmc/test.tsv", dataInfo);
System.out.println("Finish Loading Train Data");
for(int epoch = 0; epoch < 5; ++epoch) {Collections.shuffle(trainDataList);MultiDataSetIterator trainIter = new IteratorMultiDataSetIterator(trainDataList.iterator(), batchSize);dssm.fit(trainIter);Evaluation eval = dssm.evaluate(new IteratorMultiDataSetIterator(testDataList.iterator(), batchSize));System.out.println(eval);
}

通过10个epoch的训练,我们最终在验证集上得到70%左右的准确率, loss值在0.4左右。
在这里插入图片描述

5. 总结

DSSM是一个经典的双塔模型,但其也有明显的缺点,就是两个塔之间是独立的,没有信息的交叉。这种信息的交叉对应推荐场景来说是很重要的。DSSM论文中的结构比较简单,是MLP为主,且输入层使用词袋模型进行处理,这其实忽略的上下文的语义信息,因此我们在实现的时候,使用LSTM模型来捕获序列的完整语义信息。当然,由于时间原因,我们这边并没有做分词处理,相信经过分词处理,在LCQMC数据集上的准确率可以进一步得到提升。另外,双塔的结构可以很灵活,内部可以直接上BERT来做,这里变体就太多,不做过多陈述了。

这篇关于Deeplearning4j 实战 (22):基于DSSM的语义匹配建模的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/560693

相关文章

Python实战之屏幕录制功能的实现

《Python实战之屏幕录制功能的实现》屏幕录制,即屏幕捕获,是指将计算机屏幕上的活动记录下来,生成视频文件,本文主要为大家介绍了如何使用Python实现这一功能,希望对大家有所帮助... 目录屏幕录制原理图像捕获音频捕获编码压缩输出保存完整的屏幕录制工具高级功能实时预览增加水印多平台支持屏幕录制原理屏幕

最新Spring Security实战教程之Spring Security安全框架指南

《最新SpringSecurity实战教程之SpringSecurity安全框架指南》SpringSecurity是Spring生态系统中的核心组件,提供认证、授权和防护机制,以保护应用免受各种安... 目录前言什么是Spring Security?同类框架对比Spring Security典型应用场景传统

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

OpenManus本地部署实战亲测有效完全免费(最新推荐)

《OpenManus本地部署实战亲测有效完全免费(最新推荐)》文章介绍了如何在本地部署OpenManus大语言模型,包括环境搭建、LLM编程接口配置和测试步骤,本文给大家讲解的非常详细,感兴趣的朋友一... 目录1.概况2.环境搭建2.1安装miniconda或者anaconda2.2 LLM编程接口配置2

基于Canvas的Html5多时区动态时钟实战代码

《基于Canvas的Html5多时区动态时钟实战代码》:本文主要介绍了如何使用Canvas在HTML5上实现一个多时区动态时钟的web展示,通过Canvas的API,可以绘制出6个不同城市的时钟,并且这些时钟可以动态转动,每个时钟上都会标注出对应的24小时制时间,详细内容请阅读本文,希望能对你有所帮助...

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

Nginx中location实现多条件匹配的方法详解

《Nginx中location实现多条件匹配的方法详解》在Nginx中,location指令用于匹配请求的URI,虽然location本身是基于单一匹配规则的,但可以通过多种方式实现多个条件的匹配逻辑... 目录1. 概述2. 实现多条件匹配的方式2.1 使用多个 location 块2.2 使用正则表达式

Python与DeepSeek的深度融合实战

《Python与DeepSeek的深度融合实战》Python作为最受欢迎的编程语言之一,以其简洁易读的语法、丰富的库和广泛的应用场景,成为了无数开发者的首选,而DeepSeek,作为人工智能领域的新星... 目录一、python与DeepSeek的结合优势二、模型训练1. 数据准备2. 模型架构与参数设置3

Java实战之利用POI生成Excel图表

《Java实战之利用POI生成Excel图表》ApachePOI是Java生态中处理Office文档的核心工具,这篇文章主要为大家详细介绍了如何在Excel中创建折线图,柱状图,饼图等常见图表,需要的... 目录一、环境配置与依赖管理二、数据源准备与工作表构建三、图表生成核心步骤1. 折线图(Line Ch

golang字符串匹配算法解读

《golang字符串匹配算法解读》文章介绍了字符串匹配算法的原理,特别是Knuth-Morris-Pratt(KMP)算法,该算法通过构建模式串的前缀表来减少匹配时的不必要的字符比较,从而提高效率,在... 目录简介KMP实现代码总结简介字符串匹配算法主要用于在一个较长的文本串中查找一个较短的字符串(称为