NLP实践——文本生成中停不下来的问题

2023-10-23 08:59

本文主要是介绍NLP实践——文本生成中停不下来的问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NLP实践——文本生成中停不下来的问题

  • 1. 问题概述
  • 2. 造成的原因
  • 3. 解决的方法
  • 4. 效果

1. 问题概述

对于NLG任务,在推理阶段可能经常会遇到“停不下来”的问题,即重复的token被反复预测出来。
例如,输入“Google”,翻译模型可能会翻译为“谷歌谷歌”。

这个问题已经有很多人研究很久了,在模型侧提出的应对方案也有很多,本文介绍最简便的一种处理方法,只需要添加一行代码,就可以有效地改善。

2. 造成的原因

对于这种现象出现的原因,有很多相关的分析和介绍,其中苏神的这篇文章让我感到受益匪浅,从数学的角度分析了为什么会重复,非常建议大家读一下这篇文章。

3. 解决的方法

其实在transformers的源码中,以及预置了一个参数,用来控制对重复出现token的惩罚,思想非常朴素,最早应该是出现在CTRL的论文中:
https://arxiv.org/pdf/1909.05858.pdf

我们来看一下论文里是怎么描述的:
ctrl
在生成的时候,就是在计算词表中词汇的概率嘛,如果我们不希望之前出现的token连续出现,那只要把出现过的token对应的得分,人为地降低就好了,也就是给它一个惩罚的力度,让它变小一点。

反应在代码中,就是transformers/generation_utils.py中的GenerationMixin.generate方法,其中的repetition_penalty参数,就是用来控制这个惩罚的,也就是论文中的theta。

这个参数必须为大于0的浮点数,当取值为1.0时,相当于什么也没有做。如果在调用generate的时候给了这个参数,则会创建一个RepetitionPenaltyLogitsProcessor,简单看一下这个Processor是如何运作的:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):r""":class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.Args:repetition_penalty (:obj:`float`):The parameter for repetition penalty. 1.0 means no penalty. See `this paper<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details."""def __init__(self, penalty: float):if not isinstance(penalty, float) or not (penalty > 0):raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")self.penalty = penaltydef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:score = torch.gather(scores, 1, input_ids)# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probabilityscore = torch.where(score < 0, score * self.penalty, score / self.penalty)scores.scatter_(1, input_ids, score)return scores

其中input_ids就是generate时,输入的input_ids, scores是每一步推理计算出来的为下一步提供的得分。简单来说,这个类就是根据输入序列的token id,把score里边对应位置的得分取出来,然后惩罚一下这些位置的得分,让它的得分变小,然后把惩罚过的分数,替换掉原来计算出来的得分。

4. 效果

还是以翻译模型为例,采用的模型是opus-mt-en-zh,实例化这个模型:

from transformers import AutoModelWithLMHead,AutoTokenizer
mode_name = 'liam168/trans-opus-mt-en-zh'
model = AutoModelWithLMHead.from_pretrained(mode_name)
tokenizer = AutoTokenizer.from_pretrained(mode_name)

翻译一个词:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果为“谷歌谷歌”。可以看到,当输入文本很短时,很容易就出现了重复。

而如果在generate的时候,增加一个参数:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
batch['repetition_penalty'] = 1.2   # 论文中默认的参数1.2
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果就变成了只有一个"谷歌"。

再大胆一点,如果把惩罚力度设置为无穷大,也会出问题。当设置惩罚为float('inf')时,在翻译句子“Google has Google translate”的时候,就会变成“谷歌有Google翻译”,第二个Google就因为被惩罚了而没有翻译成谷歌,而如果惩罚为1.2,则翻译结果为“谷歌有谷歌翻译”。所以惩罚力度设置为多大,还需要自己把握一下。

这篇关于NLP实践——文本生成中停不下来的问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266847

相关文章

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

springboot循环依赖问题案例代码及解决办法

《springboot循环依赖问题案例代码及解决办法》在SpringBoot中,如果两个或多个Bean之间存在循环依赖(即BeanA依赖BeanB,而BeanB又依赖BeanA),会导致Spring的... 目录1. 什么是循环依赖?2. 循环依赖的场景案例3. 解决循环依赖的常见方法方法 1:使用 @La

Spring Boot 配置文件之类型、加载顺序与最佳实践记录

《SpringBoot配置文件之类型、加载顺序与最佳实践记录》SpringBoot的配置文件是灵活且强大的工具,通过合理的配置管理,可以让应用开发和部署更加高效,无论是简单的属性配置,还是复杂... 目录Spring Boot 配置文件详解一、Spring Boot 配置文件类型1.1 applicatio

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

tomcat多实例部署的项目实践

《tomcat多实例部署的项目实践》Tomcat多实例是指在一台设备上运行多个Tomcat服务,这些Tomcat相互独立,本文主要介绍了tomcat多实例部署的项目实践,具有一定的参考价值,感兴趣的可... 目录1.创建项目目录,测试文China编程件2js.创建实例的安装目录3.准备实例的配置文件4.编辑实例的

SpringBoot启动报错的11个高频问题排查与解决终极指南

《SpringBoot启动报错的11个高频问题排查与解决终极指南》这篇文章主要为大家详细介绍了SpringBoot启动报错的11个高频问题的排查与解决,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一... 目录1. 依赖冲突:NoSuchMethodError 的终极解法2. Bean注入失败:No qu

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1