微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

2023-10-23 08:59

本文主要是介绍微调预训练模型方式的文本语义匹配(Further Pretraining Bert),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

今年带着小伙伴参加了天池赛道三: 小布助手对话短文本语义匹配比赛,虽然最后没有杀进B榜,但也是预料之中的结果,最后成绩在110名左右,还算能接受。

言归正传,本文会解说苏剑林(苏神)的Baseline方案和代码,然后会分享我在Baseline上使用的tricks还有我们的方案和实验结果。

干货
Github:https://github.com/Ludong418/gaic-2021-task3
资料集合:个人整理的资料大全

苏神的Baseline

苏神在第二次发布数据没几天就公布了自己的Baseline,线上成绩大概是86左右,他的方案是mlm和文本语义匹配两个任务同时进行。

方案
词典不匹配问题

由于官方发布的数据是脱敏的,所以不建议大家直接使用其他已经训练好的预训练模型,当然也有人直接使用了,也有一定的效果,也有人好奇为什么脱敏数据不要用已经训练好的预训练模型呢?因为存在脱敏数据和预训练词典不匹配问题。nlp在深度学习预处理有一步很重要的过程就是 tokens 转 ids,而脱敏数据是已经转好的 ids,也就是下面形式。所以例子中的 ‘2’ 代表的意思是 ‘我’,若使用了预训练模型字典的 ‘2’,那就确实是 ‘2’ 的含义了。
在这里插入图片描述
所以我们就需要重新训练一个预训练模型。但是我们都知道一个新的预训练模型需要大量的数据和算力,如何去做呢?

如何训练脱敏数据的预训练模型

我们可以只保留模型参数部分,而tokens embeddings table可以替换掉,举个例子就是预训练模型中20000多个embedded table,随机换成自己数据词典大小的6000多个tokens,苏神的替换方式是保留了bert语料中token频数top6000+的embedded,当然我们也可以随机初始化6000多个embedded,两种方式我发现效果差不多。
在这里插入图片描述
但是别忘几个特殊的token要加入字典中,‘no’ 和 ‘yes’ 就是文本对的标签,‘相似’ 和 ‘不相似’,这两个token很关键。

0: pad, 1: unk, 2: cls, 3: sep, 4: mask, 5: no, 6: yes

最后就可以训练一个mlm(masked language model)了,但是苏神不只是就是训练一个mlm,而是在训练mlm过程中随便把文本语义匹配也做了,它使用第一个token([cls])的输出作为文本语义匹配任务的输出。

模型代码
预处理

这份数据的预处理过程比较简单,重点还是讲解mlm任务的输入格式,注意到output_ids第一个token要 +5, 目的就是用 [cls] 来预测 yes 或者 no。

def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)# +5 目的就是用 [cls] 来预测 yes 或者 nooutput_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids
模型

模型是一个mlm任务,就是希望被masked的token进行预测真实的token,完成一个完形填空的任务,而我们希望输出的第一个token([cls])用来预测 ‘yes’ 或者 ‘no’ 这两个token。
在这里插入图片描述
以下就是评估模型的代码:

def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)

注意到这行代码,其实就是选择了output的no和yes所在的维度的值进行预测的。

# y_pred shape:[batch_size, max_seq_len, voc_szie] 
y_pred = model.predict(x_true)[:, 0, 5:7]

在这里插入图片描述

我的方案

我们的做法差不多,只不过是预训练模型和文本语义匹配分开做了,我们先使用了脱敏数据和nazha预训练模型微调了一个新的预训练模型,然后利用新的预训练模型完成文本语义匹配二分类任务。

Further Pretraining

预训练模型使用了中文nezha-pytorch版本训练好的模型进行微调,使用了NeZhaForMaskedLM class 做mlm任务,但是源码是对全词计算loss,所以我修改了计算loss方法,只对masked的token计算loss。修改代码如下:

if labels is not None:# 只对mask的部分进行计算masked_lm_positions = torch.where(labels.view(-1) != 0)loss_fct = CrossEntropyLoss()  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size)[masked_lm_positions],labels.view(-1)[masked_lm_positions])outputs = (masked_lm_loss,) + outputsreturn outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
文本语义匹配

利用新的预训练模型做二分类任务,做法和常规的bert分类没有啥区别,最后线上效果能达到90左右,要提高模型的效果还得需要一些小tricks。

trick 1:数据增强

目的是为了增加训练数据数量,效果能略有提升零点几个点。
句子等价替换:如果 句子A = 句子B, 句子B = 句子C, 则 句子A = 句子C
句子对调:把所有的句子对调

trick 2:对抗式学习

我们测试了FGM和VAT两种对抗学习,效果都有提升1个点左右,VAT也可以使用在mlm任务中,但是发现效果并不是很好。

trick 3:伪标签

利用训练好的模型对test数据集预测,把输出概率大于阈值的数据加入到训练集中训练,效果提升不明显。

trick 4:半监督学习(没有实现)

可以考虑使用半监督学习,例如mixtext、mixup等模型,目的也是为了增大训练模型样本,但这次比赛中没来得及去实验,但是在工作中的一个项目中使用了mixtext模型,效果有很大的提升。

trick 4:置信学习(没有实现)

数据集中总会或多或少出现错误标签,若能剔除掉这些错误标签在进行训练,效果会有一定的提升,所以可以考虑使用置信学习的方法去发现错误标签。

结论

本次比赛没能杀入B榜主要还是因为身为打工人,我们只能在下班和周末搞搞,很多想法并不能有足够的时间去实验,不像高校里的学生论文看的多,时间也足够。其次就是算力没跟上,显卡不够多,一个mlm模型要花一天多才能训练完成,然后加入对抗式学习速度更慢了。

整个任务的难度不大,一开始碰到脱敏数据我也懵了一阵,我也想到自己要训练一个预训练模型,但是不敢确定效果,但最后看来效果确实不错,所以深度学习这门技术还是得靠动手验证的科学,最后推荐一篇不错的论文Don’t Stop Pretraining: Adapt Language Models to Domains and Tasks。

这篇关于微调预训练模型方式的文本语义匹配(Further Pretraining Bert)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266848

相关文章

Java中List转Map的几种具体实现方式和特点

《Java中List转Map的几种具体实现方式和特点》:本文主要介绍几种常用的List转Map的方式,包括使用for循环遍历、Java8StreamAPI、ApacheCommonsCollect... 目录前言1、使用for循环遍历:2、Java8 Stream API:3、Apache Commons

虚拟机与物理机的文件共享方式

《虚拟机与物理机的文件共享方式》文章介绍了如何在KaliLinux虚拟机中实现物理机文件夹的直接挂载,以便在虚拟机中方便地读取和使用物理机上的文件,通过设置和配置,可以实现临时挂载和永久挂载,并提供... 目录虚拟机与物理机的文件共享1 虚拟机设置2 验证Kali下分享文件夹功能是否启用3 创建挂载目录4

linux报错INFO:task xxxxxx:634 blocked for more than 120 seconds.三种解决方式

《linux报错INFO:taskxxxxxx:634blockedformorethan120seconds.三种解决方式》文章描述了一个Linux最小系统运行时出现的“hung_ta... 目录1.问题描述2.解决办法2.1 缩小文件系统缓存大小2.2 修改系统IO调度策略2.3 取消120秒时间限制3

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

Mybatis官方生成器的使用方式

《Mybatis官方生成器的使用方式》本文详细介绍了MyBatisGenerator(MBG)的使用方法,通过实际代码示例展示了如何配置Maven插件来自动化生成MyBatis项目所需的实体类、Map... 目录1. MyBATis Generator 简介2. MyBatis Generator 的功能3

通过C#获取PDF中指定文本或所有文本的字体信息

《通过C#获取PDF中指定文本或所有文本的字体信息》在设计和出版行业中,字体的选择和使用对最终作品的质量有着重要影响,然而,有时我们可能会遇到包含未知字体的PDF文件,这使得我们无法准确地复制或修改文... 目录引言C# 获取PDF中指定文本的字体信息C# 获取PDF文档中用到的所有字体信息引言在设计和出

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档

SpringBoot项目启动后自动加载系统配置的多种实现方式

《SpringBoot项目启动后自动加载系统配置的多种实现方式》:本文主要介绍SpringBoot项目启动后自动加载系统配置的多种实现方式,并通过代码示例讲解的非常详细,对大家的学习或工作有一定的... 目录1. 使用 CommandLineRunner实现方式:2. 使用 ApplicationRunne

VUE动态绑定class类的三种常用方式及适用场景详解

《VUE动态绑定class类的三种常用方式及适用场景详解》文章介绍了在实际开发中动态绑定class的三种常见情况及其解决方案,包括根据不同的返回值渲染不同的class样式、给模块添加基础样式以及根据设... 目录前言1.动态选择class样式(对象添加:情景一)2.动态添加一个class样式(字符串添加:情