微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

2023-10-23 08:59

本文主要是介绍微调预训练模型方式的文本语义匹配(Further Pretraining Bert),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

今年带着小伙伴参加了天池赛道三: 小布助手对话短文本语义匹配比赛,虽然最后没有杀进B榜,但也是预料之中的结果,最后成绩在110名左右,还算能接受。

言归正传,本文会解说苏剑林(苏神)的Baseline方案和代码,然后会分享我在Baseline上使用的tricks还有我们的方案和实验结果。

干货
Github:https://github.com/Ludong418/gaic-2021-task3
资料集合:个人整理的资料大全

苏神的Baseline

苏神在第二次发布数据没几天就公布了自己的Baseline,线上成绩大概是86左右,他的方案是mlm和文本语义匹配两个任务同时进行。

方案
词典不匹配问题

由于官方发布的数据是脱敏的,所以不建议大家直接使用其他已经训练好的预训练模型,当然也有人直接使用了,也有一定的效果,也有人好奇为什么脱敏数据不要用已经训练好的预训练模型呢?因为存在脱敏数据和预训练词典不匹配问题。nlp在深度学习预处理有一步很重要的过程就是 tokens 转 ids,而脱敏数据是已经转好的 ids,也就是下面形式。所以例子中的 ‘2’ 代表的意思是 ‘我’,若使用了预训练模型字典的 ‘2’,那就确实是 ‘2’ 的含义了。
在这里插入图片描述
所以我们就需要重新训练一个预训练模型。但是我们都知道一个新的预训练模型需要大量的数据和算力,如何去做呢?

如何训练脱敏数据的预训练模型

我们可以只保留模型参数部分,而tokens embeddings table可以替换掉,举个例子就是预训练模型中20000多个embedded table,随机换成自己数据词典大小的6000多个tokens,苏神的替换方式是保留了bert语料中token频数top6000+的embedded,当然我们也可以随机初始化6000多个embedded,两种方式我发现效果差不多。
在这里插入图片描述
但是别忘几个特殊的token要加入字典中,‘no’ 和 ‘yes’ 就是文本对的标签,‘相似’ 和 ‘不相似’,这两个token很关键。

0: pad, 1: unk, 2: cls, 3: sep, 4: mask, 5: no, 6: yes

最后就可以训练一个mlm(masked language model)了,但是苏神不只是就是训练一个mlm,而是在训练mlm过程中随便把文本语义匹配也做了,它使用第一个token([cls])的输出作为文本语义匹配任务的输出。

模型代码
预处理

这份数据的预处理过程比较简单,重点还是讲解mlm任务的输入格式,注意到output_ids第一个token要 +5, 目的就是用 [cls] 来预测 yes 或者 no。

def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)# +5 目的就是用 [cls] 来预测 yes 或者 nooutput_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids
模型

模型是一个mlm任务,就是希望被masked的token进行预测真实的token,完成一个完形填空的任务,而我们希望输出的第一个token([cls])用来预测 ‘yes’ 或者 ‘no’ 这两个token。
在这里插入图片描述
以下就是评估模型的代码:

def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)

注意到这行代码,其实就是选择了output的no和yes所在的维度的值进行预测的。

# y_pred shape:[batch_size, max_seq_len, voc_szie] 
y_pred = model.predict(x_true)[:, 0, 5:7]

在这里插入图片描述

我的方案

我们的做法差不多,只不过是预训练模型和文本语义匹配分开做了,我们先使用了脱敏数据和nazha预训练模型微调了一个新的预训练模型,然后利用新的预训练模型完成文本语义匹配二分类任务。

Further Pretraining

预训练模型使用了中文nezha-pytorch版本训练好的模型进行微调,使用了NeZhaForMaskedLM class 做mlm任务,但是源码是对全词计算loss,所以我修改了计算loss方法,只对masked的token计算loss。修改代码如下:

if labels is not None:# 只对mask的部分进行计算masked_lm_positions = torch.where(labels.view(-1) != 0)loss_fct = CrossEntropyLoss()  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size)[masked_lm_positions],labels.view(-1)[masked_lm_positions])outputs = (masked_lm_loss,) + outputsreturn outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
文本语义匹配

利用新的预训练模型做二分类任务,做法和常规的bert分类没有啥区别,最后线上效果能达到90左右,要提高模型的效果还得需要一些小tricks。

trick 1:数据增强

目的是为了增加训练数据数量,效果能略有提升零点几个点。
句子等价替换:如果 句子A = 句子B, 句子B = 句子C, 则 句子A = 句子C
句子对调:把所有的句子对调

trick 2:对抗式学习

我们测试了FGM和VAT两种对抗学习,效果都有提升1个点左右,VAT也可以使用在mlm任务中,但是发现效果并不是很好。

trick 3:伪标签

利用训练好的模型对test数据集预测,把输出概率大于阈值的数据加入到训练集中训练,效果提升不明显。

trick 4:半监督学习(没有实现)

可以考虑使用半监督学习,例如mixtext、mixup等模型,目的也是为了增大训练模型样本,但这次比赛中没来得及去实验,但是在工作中的一个项目中使用了mixtext模型,效果有很大的提升。

trick 4:置信学习(没有实现)

数据集中总会或多或少出现错误标签,若能剔除掉这些错误标签在进行训练,效果会有一定的提升,所以可以考虑使用置信学习的方法去发现错误标签。

结论

本次比赛没能杀入B榜主要还是因为身为打工人,我们只能在下班和周末搞搞,很多想法并不能有足够的时间去实验,不像高校里的学生论文看的多,时间也足够。其次就是算力没跟上,显卡不够多,一个mlm模型要花一天多才能训练完成,然后加入对抗式学习速度更慢了。

整个任务的难度不大,一开始碰到脱敏数据我也懵了一阵,我也想到自己要训练一个预训练模型,但是不敢确定效果,但最后看来效果确实不错,所以深度学习这门技术还是得靠动手验证的科学,最后推荐一篇不错的论文Don’t Stop Pretraining: Adapt Language Models to Domains and Tasks。

这篇关于微调预训练模型方式的文本语义匹配(Further Pretraining Bert)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266848

相关文章

java两个List的交集,并集方式

《java两个List的交集,并集方式》文章主要介绍了Java中两个List的交集和并集的处理方法,推荐使用Apache的CollectionUtils工具类,因为它简单且不会改变原有集合,同时,文章... 目录Java两个List的交集,并集方法一方法二方法三总结java两个List的交集,并集方法一

Python中如何控制小数点精度与对齐方式

《Python中如何控制小数点精度与对齐方式》在Python编程中,数据输出格式化是一个常见的需求,尤其是在涉及到小数点精度和对齐方式时,下面小编就来为大家介绍一下如何在Python中实现这些功能吧... 目录一、控制小数点精度1. 使用 round() 函数2. 使用字符串格式化二、控制对齐方式1. 使用

Nginx中location实现多条件匹配的方法详解

《Nginx中location实现多条件匹配的方法详解》在Nginx中,location指令用于匹配请求的URI,虽然location本身是基于单一匹配规则的,但可以通过多种方式实现多个条件的匹配逻辑... 目录1. 概述2. 实现多条件匹配的方式2.1 使用多个 location 块2.2 使用正则表达式

Nginx配置系统服务&设置环境变量方式

《Nginx配置系统服务&设置环境变量方式》本文介绍了如何将Nginx配置为系统服务并设置环境变量,以便更方便地对Nginx进行操作,通过配置系统服务,可以使用系统命令来启动、停止或重新加载Nginx... 目录1.Nginx操作问题2.配置系统服android务3.设置环境变量总结1.Nginx操作问题

Go 1.23中Timer无buffer的实现方式详解

《Go1.23中Timer无buffer的实现方式详解》在Go1.23中,Timer的实现通常是通过time包提供的time.Timer类型来实现的,本文主要介绍了Go1.23中Timer无buff... 目录Timer 的基本实现无缓冲区的实现自定义无缓冲 Timer 实现更复杂的 Timer 实现总结在

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

nginx upstream六种方式分配小结

《nginxupstream六种方式分配小结》本文主要介绍了nginxupstream六种方式分配小结,包括轮询、加权轮询、IP哈希、公平轮询、URL哈希和备份服务器,具有一定的参考价格,感兴趣的可... 目录1 轮询(默认)2 weight3 ip_hash4 fair(第三方)5 url_hash(第三

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

linux打包解压命令方式

《linux打包解压命令方式》文章介绍了Linux系统中常用的打包和解压命令,包括tar和zip,使用tar命令可以创建和解压tar格式的归档文件,使用zip命令可以创建和解压zip格式的压缩文件,每... 目录Lijavascriptnux 打包和解压命令打包命令解压命令总结linux 打包和解压命令打

Python中常用的四种取整方式分享

《Python中常用的四种取整方式分享》在数据处理和数值计算中,取整操作是非常常见的需求,Python提供了多种取整方式,本文为大家整理了四种常用的方法,希望对大家有所帮助... 目录引言向零取整(Truncate)向下取整(Floor)向上取整(Ceil)四舍五入(Round)四种取整方式的对比综合示例应