对话式AI——多轮对话拼接

2023-10-28 07:30
文章标签 ai 拼接 对话 多轮

本文主要是介绍对话式AI——多轮对话拼接,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 介绍

        oppo 举办的上下文拼接算法        比赛官网
在这里插入图片描述

1.1 比赛任务:

        本次比赛使用OPPO小布助手开放的“对话式指代消解与省略恢复”数据集。数据集中包括了3万条对话交互数据。每条数据样本提供三轮对话,分别是上轮query、上轮应答和本轮query,选手需要使用算法技术将本轮query(即第三轮)处理成上下文无关的query。

1.2 数据介绍:

本数据集为训练集,包含以下内容:
        每行采用json格式,用于表示一个样本。每条训练数据由query-01、response-01、query-02、query-02-rewrite四部分组成,分别是上轮query、上轮应答,本轮query,本轮query对应的上下文无关的query。具体格式举例:

{"query-01": "你喜欢周杰伦吗","response-01": "喜欢呀","query-02": "来唱首他的歌","query-02-rewrite": "来唱首周杰伦的歌"
}
# 输入:query-01、response-01、query-02
# 输出:query-02-rewrite

2 数据预处理

        将数据传入迭代器

class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):batch_token_ids, batch_segment_ids = [], []for is_end, item in self.sample(random):q1 = item['query-01']r1 = item['response-01']q2 = item['query-02']qr = item['query-02-rewrite']token_ids, segment_ids = [tokenizer._token_start_id], [0]ids_q1 = tokenizer.encode(q1)[0][1:]ids_r1 = tokenizer.encode(r1)[0][1:]ids_q2 = tokenizer.encode(q2)[0][1:]ids_qr = tokenizer.encode(qr)[0][1:]token_ids.extend(ids_q1)segment_ids.extend([0] * len(ids_q1))token_ids.extend(ids_r1)segment_ids.extend([1] * len(ids_r1))token_ids.extend(ids_q2)segment_ids.extend([0] * len(ids_q2))token_ids.extend(ids_qr)segment_ids.extend([1] * len(ids_qr))batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)if len(batch_token_ids) == self.batch_size or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)yield [batch_token_ids, batch_segment_ids], Nonebatch_token_ids, batch_segment_ids = [], []

3 模型

        使用苏神bert4keras框架,模型是unilm,直接用单个Bert的架构做Seq2Seq,详细介绍请转入从语言模型到Seq2Seq。
        使用unilm搭建成本次比赛的模型,模型结构如下:
在这里插入图片描述

        模型搭建代码:

class CrossEntropy(Loss):"""交叉熵作为loss,并mask掉输入部分"""def compute_loss(self, inputs, mask=None):y_true, y_mask, y_pred = inputsy_true = y_true[:, 1:]  # 目标token_idsy_mask = y_mask[:, 1:]  # segment_ids,刚好指示了要预测的部分y_pred = y_pred[:, :-1]  # 预测序列,错开一位loss = K.sparse_categorical_crossentropy(y_true, y_pred)loss = K.sum(loss * y_mask) / K.sum(y_mask)return lossmodel = build_transformer_model(config_path,checkpoint_path,application='unilm',keep_tokens=keep_tokens,  # 只保留keep_tokens中的字,精简原字表
)
output = CrossEntropy(2)(model.inputs + model.outputs)
model = Model(model.inputs, output)
model.compile(optimizer=Adam(1e-5))
model.summary()

4 解码器

        使用beam_search来解码

class AutoTitle(AutoRegressiveDecoder):"""seq2seq解码器"""@AutoRegressiveDecoder.wraps(default_rtype='probas')def predict(self, inputs, output_ids, states):token_ids, segment_ids = inputstoken_ids = np.concatenate([token_ids, output_ids], 1)segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)return self.last_token(model).predict([token_ids, segment_ids])def generate(self, item, topk=1):q1 = item['query-01']r1 = item['response-01']q2 = item['query-02']qr = item['query-02-rewrite']token_ids, segment_ids = [tokenizer._token_start_id], [0]ids_q1 = tokenizer.encode(q1)[0][1:]ids_r1 = tokenizer.encode(r1)[0][1:]ids_q2 = tokenizer.encode(q2)[0][1:]ids_qr = tokenizer.encode(qr)[0][1:]token_ids.extend(ids_q1)segment_ids.extend([0] * len(ids_q1))token_ids.extend(ids_r1)segment_ids.extend([1] * len(ids_r1))token_ids.extend(ids_q2)segment_ids.extend([0] * len(ids_q2))# token_ids.extend(ids_qr)# segment_ids.extend([1] * len(ids_qr))output_ids = self.beam_search([token_ids, segment_ids], topk=topk)  # 基于beam searchreturn tokenizer.decode(output_ids)autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_end_id, maxlen=32)

5 训练

class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self):self.lowest = 1e10def on_epoch_end(self, epoch, logs=None):# 保存最优if logs['loss'] <= self.lowest:self.lowest = logs['loss']save_name = r'data/best_model_%d.weights' % (epoch)model.save_weights(save_name)# 演示效果just_show()file = open('data/train.txt', 'r', encoding='utf-8')
train_data = [json.loads(line, encoding='utf-8') for line in file.read().split('\n')]
# train_data = train_data[:100]
evaluator = Evaluator()
train_generator = data_generator(train_data, batch_size)
model.fit(train_generator.forfit(),steps_per_epoch=len(train_generator),epochs=epochs,callbacks=[evaluator]
)

本次比赛主要是参与为主,没有花费太多时间,只提交了5次,官网中的准确率在86%以上。
在这里插入图片描述

这篇关于对话式AI——多轮对话拼接的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/291655

相关文章

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

python中字符串拼接的几种方法及优缺点对比详解

《python中字符串拼接的几种方法及优缺点对比详解》在Python中,字符串拼接是常见的操作,Python提供了多种方法来拼接字符串,每种方法有其优缺点和适用场景,以下是几种常见的字符串拼接方法,需... 目录1. 使用 + 运算符示例:优缺点:2. 使用&nbsjsp;join() 方法示例:优缺点:3

Spring AI ectorStore的使用流程

《SpringAIectorStore的使用流程》SpringAI中的VectorStore是一种用于存储和检索高维向量数据的数据库或存储解决方案,它在AI应用中发挥着至关重要的作用,本文给大家介... 目录一、VectorStore的基本概念二、VectorStore的核心接口三、VectorStore的

Golang中拼接字符串的6种方式性能对比

《Golang中拼接字符串的6种方式性能对比》golang的string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去,主要有6种拼接方式,下面小编就来为大家详细讲讲吧... 目录拼接方式介绍性能对比测试代码测试结果源码分析golang的string类型是不可修改的,对于拼接字

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

Spring AI集成DeepSeek的详细步骤

《SpringAI集成DeepSeek的详细步骤》DeepSeek作为一款卓越的国产AI模型,越来越多的公司考虑在自己的应用中集成,对于Java应用来说,我们可以借助SpringAI集成DeepSe... 目录DeepSeek 介绍Spring AI 是什么?1、环境准备2、构建项目2.1、pom依赖2.2

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll