NLP实践——文本生成中停不下来的问题

2023-10-23 08:59

本文主要是介绍NLP实践——文本生成中停不下来的问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NLP实践——文本生成中停不下来的问题

  • 1. 问题概述
  • 2. 造成的原因
  • 3. 解决的方法
  • 4. 效果

1. 问题概述

对于NLG任务,在推理阶段可能经常会遇到“停不下来”的问题,即重复的token被反复预测出来。
例如,输入“Google”,翻译模型可能会翻译为“谷歌谷歌”。

这个问题已经有很多人研究很久了,在模型侧提出的应对方案也有很多,本文介绍最简便的一种处理方法,只需要添加一行代码,就可以有效地改善。

2. 造成的原因

对于这种现象出现的原因,有很多相关的分析和介绍,其中苏神的这篇文章让我感到受益匪浅,从数学的角度分析了为什么会重复,非常建议大家读一下这篇文章。

3. 解决的方法

其实在transformers的源码中,以及预置了一个参数,用来控制对重复出现token的惩罚,思想非常朴素,最早应该是出现在CTRL的论文中:
https://arxiv.org/pdf/1909.05858.pdf

我们来看一下论文里是怎么描述的:
ctrl
在生成的时候,就是在计算词表中词汇的概率嘛,如果我们不希望之前出现的token连续出现,那只要把出现过的token对应的得分,人为地降低就好了,也就是给它一个惩罚的力度,让它变小一点。

反应在代码中,就是transformers/generation_utils.py中的GenerationMixin.generate方法,其中的repetition_penalty参数,就是用来控制这个惩罚的,也就是论文中的theta。

这个参数必须为大于0的浮点数,当取值为1.0时,相当于什么也没有做。如果在调用generate的时候给了这个参数,则会创建一个RepetitionPenaltyLogitsProcessor,简单看一下这个Processor是如何运作的:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):r""":class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.Args:repetition_penalty (:obj:`float`):The parameter for repetition penalty. 1.0 means no penalty. See `this paper<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details."""def __init__(self, penalty: float):if not isinstance(penalty, float) or not (penalty > 0):raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")self.penalty = penaltydef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:score = torch.gather(scores, 1, input_ids)# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probabilityscore = torch.where(score < 0, score * self.penalty, score / self.penalty)scores.scatter_(1, input_ids, score)return scores

其中input_ids就是generate时,输入的input_ids, scores是每一步推理计算出来的为下一步提供的得分。简单来说,这个类就是根据输入序列的token id,把score里边对应位置的得分取出来,然后惩罚一下这些位置的得分,让它的得分变小,然后把惩罚过的分数,替换掉原来计算出来的得分。

4. 效果

还是以翻译模型为例,采用的模型是opus-mt-en-zh,实例化这个模型:

from transformers import AutoModelWithLMHead,AutoTokenizer
mode_name = 'liam168/trans-opus-mt-en-zh'
model = AutoModelWithLMHead.from_pretrained(mode_name)
tokenizer = AutoTokenizer.from_pretrained(mode_name)

翻译一个词:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果为“谷歌谷歌”。可以看到,当输入文本很短时,很容易就出现了重复。

而如果在generate的时候,增加一个参数:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
batch['repetition_penalty'] = 1.2   # 论文中默认的参数1.2
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果就变成了只有一个"谷歌"。

再大胆一点,如果把惩罚力度设置为无穷大,也会出问题。当设置惩罚为float('inf')时,在翻译句子“Google has Google translate”的时候,就会变成“谷歌有Google翻译”,第二个Google就因为被惩罚了而没有翻译成谷歌,而如果惩罚为1.2,则翻译结果为“谷歌有谷歌翻译”。所以惩罚力度设置为多大,还需要自己把握一下。

这篇关于NLP实践——文本生成中停不下来的问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266847

相关文章

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

好题——hdu2522(小数问题:求1/n的第一个循环节)

好喜欢这题,第一次做小数问题,一开始真心没思路,然后参考了网上的一些资料。 知识点***********************************无限不循环小数即无理数,不能写作两整数之比*****************************(一开始没想到,小学没学好) 此题1/n肯定是一个有限循环小数,了解这些后就能做此题了。 按照除法的机制,用一个函数表示出来就可以了,代码如下

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

购买磨轮平衡机时应该注意什么问题和技巧

在购买磨轮平衡机时,您应该注意以下几个关键点: 平衡精度 平衡精度是衡量平衡机性能的核心指标,直接影响到不平衡量的检测与校准的准确性,从而决定磨轮的振动和噪声水平。高精度的平衡机能显著减少振动和噪声,提高磨削加工的精度。 转速范围 宽广的转速范围意味着平衡机能够处理更多种类的磨轮,适应不同的工作条件和规格要求。 振动监测能力 振动监测能力是评估平衡机性能的重要因素。通过传感器实时监

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

缓存雪崩问题

缓存雪崩是缓存中大量key失效后当高并发到来时导致大量请求到数据库,瞬间耗尽数据库资源,导致数据库无法使用。 解决方案: 1、使用锁进行控制 2、对同一类型信息的key设置不同的过期时间 3、缓存预热 1. 什么是缓存雪崩 缓存雪崩是指在短时间内,大量缓存数据同时失效,导致所有请求直接涌向数据库,瞬间增加数据库的负载压力,可能导致数据库性能下降甚至崩溃。这种情况往往发生在缓存中大量 k