对话式AI——多轮对话拼接

2023-10-28 07:30
文章标签 ai 拼接 对话 多轮

本文主要是介绍对话式AI——多轮对话拼接,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 介绍

        oppo 举办的上下文拼接算法        比赛官网
在这里插入图片描述

1.1 比赛任务:

        本次比赛使用OPPO小布助手开放的“对话式指代消解与省略恢复”数据集。数据集中包括了3万条对话交互数据。每条数据样本提供三轮对话,分别是上轮query、上轮应答和本轮query,选手需要使用算法技术将本轮query(即第三轮)处理成上下文无关的query。

1.2 数据介绍:

本数据集为训练集,包含以下内容:
        每行采用json格式,用于表示一个样本。每条训练数据由query-01、response-01、query-02、query-02-rewrite四部分组成,分别是上轮query、上轮应答,本轮query,本轮query对应的上下文无关的query。具体格式举例:

{"query-01": "你喜欢周杰伦吗","response-01": "喜欢呀","query-02": "来唱首他的歌","query-02-rewrite": "来唱首周杰伦的歌"
}
# 输入:query-01、response-01、query-02
# 输出:query-02-rewrite

2 数据预处理

        将数据传入迭代器

class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):batch_token_ids, batch_segment_ids = [], []for is_end, item in self.sample(random):q1 = item['query-01']r1 = item['response-01']q2 = item['query-02']qr = item['query-02-rewrite']token_ids, segment_ids = [tokenizer._token_start_id], [0]ids_q1 = tokenizer.encode(q1)[0][1:]ids_r1 = tokenizer.encode(r1)[0][1:]ids_q2 = tokenizer.encode(q2)[0][1:]ids_qr = tokenizer.encode(qr)[0][1:]token_ids.extend(ids_q1)segment_ids.extend([0] * len(ids_q1))token_ids.extend(ids_r1)segment_ids.extend([1] * len(ids_r1))token_ids.extend(ids_q2)segment_ids.extend([0] * len(ids_q2))token_ids.extend(ids_qr)segment_ids.extend([1] * len(ids_qr))batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)if len(batch_token_ids) == self.batch_size or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)yield [batch_token_ids, batch_segment_ids], Nonebatch_token_ids, batch_segment_ids = [], []

3 模型

        使用苏神bert4keras框架,模型是unilm,直接用单个Bert的架构做Seq2Seq,详细介绍请转入从语言模型到Seq2Seq。
        使用unilm搭建成本次比赛的模型,模型结构如下:
在这里插入图片描述

        模型搭建代码:

class CrossEntropy(Loss):"""交叉熵作为loss,并mask掉输入部分"""def compute_loss(self, inputs, mask=None):y_true, y_mask, y_pred = inputsy_true = y_true[:, 1:]  # 目标token_idsy_mask = y_mask[:, 1:]  # segment_ids,刚好指示了要预测的部分y_pred = y_pred[:, :-1]  # 预测序列,错开一位loss = K.sparse_categorical_crossentropy(y_true, y_pred)loss = K.sum(loss * y_mask) / K.sum(y_mask)return lossmodel = build_transformer_model(config_path,checkpoint_path,application='unilm',keep_tokens=keep_tokens,  # 只保留keep_tokens中的字,精简原字表
)
output = CrossEntropy(2)(model.inputs + model.outputs)
model = Model(model.inputs, output)
model.compile(optimizer=Adam(1e-5))
model.summary()

4 解码器

        使用beam_search来解码

class AutoTitle(AutoRegressiveDecoder):"""seq2seq解码器"""@AutoRegressiveDecoder.wraps(default_rtype='probas')def predict(self, inputs, output_ids, states):token_ids, segment_ids = inputstoken_ids = np.concatenate([token_ids, output_ids], 1)segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)return self.last_token(model).predict([token_ids, segment_ids])def generate(self, item, topk=1):q1 = item['query-01']r1 = item['response-01']q2 = item['query-02']qr = item['query-02-rewrite']token_ids, segment_ids = [tokenizer._token_start_id], [0]ids_q1 = tokenizer.encode(q1)[0][1:]ids_r1 = tokenizer.encode(r1)[0][1:]ids_q2 = tokenizer.encode(q2)[0][1:]ids_qr = tokenizer.encode(qr)[0][1:]token_ids.extend(ids_q1)segment_ids.extend([0] * len(ids_q1))token_ids.extend(ids_r1)segment_ids.extend([1] * len(ids_r1))token_ids.extend(ids_q2)segment_ids.extend([0] * len(ids_q2))# token_ids.extend(ids_qr)# segment_ids.extend([1] * len(ids_qr))output_ids = self.beam_search([token_ids, segment_ids], topk=topk)  # 基于beam searchreturn tokenizer.decode(output_ids)autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_end_id, maxlen=32)

5 训练

class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self):self.lowest = 1e10def on_epoch_end(self, epoch, logs=None):# 保存最优if logs['loss'] <= self.lowest:self.lowest = logs['loss']save_name = r'data/best_model_%d.weights' % (epoch)model.save_weights(save_name)# 演示效果just_show()file = open('data/train.txt', 'r', encoding='utf-8')
train_data = [json.loads(line, encoding='utf-8') for line in file.read().split('\n')]
# train_data = train_data[:100]
evaluator = Evaluator()
train_generator = data_generator(train_data, batch_size)
model.fit(train_generator.forfit(),steps_per_epoch=len(train_generator),epochs=epochs,callbacks=[evaluator]
)

本次比赛主要是参与为主,没有花费太多时间,只提交了5次,官网中的准确率在86%以上。
在这里插入图片描述

这篇关于对话式AI——多轮对话拼接的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/291655

相关文章

Ubuntu系统怎么安装Warp? 新一代AI 终端神器安装使用方法

《Ubuntu系统怎么安装Warp?新一代AI终端神器安装使用方法》Warp是一款使用Rust开发的现代化AI终端工具,该怎么再Ubuntu系统中安装使用呢?下面我们就来看看详细教程... Warp Terminal 是一款使用 Rust 开发的现代化「AI 终端」工具。最初它只支持 MACOS,但在 20

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

AI行业应用(不定期更新)

ChatPDF 可以让你上传一个 PDF 文件,然后针对这个 PDF 进行小结和提问。你可以把各种各样你要研究的分析报告交给它,快速获取到想要知道的信息。https://www.chatpdf.com/

【北交大信息所AI-Max2】使用方法

BJTU信息所集群AI_MAX2使用方法 使用的前提是预约到相应的算力卡,拥有登录权限的账号密码,一般为导师组共用一个。 有浏览器、ssh工具就可以。 1.新建集群Terminal 浏览器登陆10.126.62.75 (如果是1集群把75改成66) 交互式开发 执行器选Terminal 密码随便设一个(需记住) 工作空间:私有数据、全部文件 加速器选GeForce_RTX_2080_Ti

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU