以beam search为例,详解transformers中generate方法(下)

2023-11-03 07:20

本文主要是介绍以beam search为例,详解transformers中generate方法(下),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

以beam search为例,详解transformers中generate方法(下)

  • 1. beam search原理回顾
  • 2. 代码流程概览
  • 3. BeamSearchScorer
  • 4. BeamHypotheses
  • 5. beam_search过程
    • 5.1 beam score初始化
    • 5.2 准备输入
    • 5.3 前向forward
    • 5.4 计算下一个step每个token的得分
    • 5.5 选择next token
    • 5.6 更新beam状态
    • 5.7 后处理finalize
  • 6. beam sample
  • 7. 总结

在上一篇博客中,对generate方法的基本流程逻辑进行了介绍,本文将继续之前的内容,介绍最常用的采样策略beam search是如何实现的。


1. beam search原理回顾

Beam search的原理并不复杂,可以理解为在Greedy search的基础上扩大了搜索范围。Greedy search在每一步只保留概率最大的top-1的结果,而beam search则是在此基础上,每一步保留了beam_size个结果。

例如,词表空间内总共有这几个token:[‘早’, ‘上’, ‘好’]。设置k=2,则在每一步的生成中,保留概率最大的2个结果如图所示。

2

2. 代码流程概览

为了帮助大家阅读代码,这里把这部分代码的整体逻辑进行一下梳理,如下图所示:
代码逻辑
总的来说,生成过程中不断重复调用模型的forward()计算出logits,以及调用BeamSearchScorer的process()来计算下一个位置每个token出现的得分,来生成下一个token及其概率分布,直到满足终止条件,结束生成。

3. BeamSearchScorer

BeamSearchScorer是在生成过程进行状态维护的类,它的作用是用来更新Beam得分,以及判断生成过程是否结束等。在这一节中,简单了解一下这个类的构造,具体的使用方法会在本篇的第5节中,结合beam search的整个流程的推进,进行更加详细的介绍。

先简单解释一下其参数:

参数名类型含义
batch_sizeint批量生成时一次处理多少条数据
num_beamsint每一条数据在生成时保留几个beam
devicetorch.devicecpu or cuda
length_penaltyOptional[float]控制倾向于生成更长的句子还是更短的句子
do_early_stoppingOptional[Union[bool, str]]早停机制,是否生成达到num_beam后立即停止
num_beam_hyps_to_keepint最终返回多少个beam
num_beam_groupsint把所有的beam按照差异度分成多少组
max_lengthint生成的最大长度

这个类除了构造方法之外,只有两个方法和一个属性:

    @propertydef is_done(self) -> bool:def process(self,input_ids: torch.LongTensor,next_scores: torch.FloatTensor,next_tokens: torch.LongTensor,next_indices: torch.LongTensor,pad_token_id: Optional[int] = None,eos_token_id: Optional[Union[int, List[int]]] = None,beam_indices: Optional[torch.LongTensor] = None,) -> Tuple[torch.Tensor]:def finalize(self,input_ids: torch.LongTensor,final_beam_scores: torch.FloatTensor,final_beam_tokens: torch.LongTensor,final_beam_indices: torch.LongTensor,max_length: int,pad_token_id: Optional[int] = None,eos_token_id: Optional[Union[int, List[int]]] = None,beam_indices: Optional[torch.LongTensor] = None,) -> Tuple[torch.LongTensor]:
  • 其中is_done用来记录是否batch中所有数据都已经生成结束;

  • process是生成的每一个step都需要执行的状态更新过程,属于生成中的主干部分;

  • finalize是整个生成过程所有step都已经结束之后(出现EOS或达到stopping_criteria的终止条件),最终的后处理加工。

除此之外,这个类还有两个成员需要注意:

  • self.group_size是按照差异性对beam分组时,每一组的beam数量:
    	self.group_size = self.num_beams // self.num_beam_groups
  • self._beam_hyps是一组容器,用来容纳得分最高的 n n n个beam:
        self._beam_hyps = [BeamHypotheses(num_beams=self.num_beams,length_penalty=self.length_penalty,early_stopping=self.do_early_stopping,max_length=max_length,)for _ in range(batch_size)]

接下来在第4节中,简单介绍一下这个BeamHypotheses类。

4. BeamHypotheses

BeamHypotheses,直接翻译过来就是“假说”,这个名称很容易引起迷惑,但其实把它看做是一个容器就好了,其容纳的内容就是 n n n个得分最高的beam。batch中的每个样本,对应一个BeamHypotheses。

从构造方法可以看出,其自身除了一个self.beams用来容纳得分最高的beam之外,还有若干固有的属性:

class BeamHypotheses:def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):"""Initialize n-best list of hypotheses."""self.length_penalty = length_penalty   # 与BeamScorer的length_penalty是同一个东西,用来控制倾向于生成长序列还是短序列self.early_stopping = early_stopping   # 与BeamScorer的early_stopping是同一个,控制是否采用早停机制self.max_length = max_length           # 与BeamScorer的max_length是同一个,控制生成序列的最大长度self.num_beams = num_beams             # 与BeamScorer的num_beams是同一个,生成过程中保留多少个beamself.beams = []                        # 在生成过程中,用来容纳至多num_beams个beamself.worst_score = 1e9                 # 当前状态下最差一个beam的得分if not isinstance(self.early_stopping, bool) and self.max_length is None:raise ValueError("When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"" BeamScorer class instance at initialization time.")def __len__(self):"""Number of hypotheses in the list."""return len(self.beams)

然后看一下BeamHypotheses的两个核心方法,add和is_done:

add方法用来将一个beam(对应的容器)添加到整个列表中:

    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):"""Add a new hypothesis to the list."""score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)if len(self) < self.num_beams or score > self.worst_score:self.beams.append((score, hyp, beam_indices))if len(self) > self.num_beams:# 如果超了设置的beam数量,则按照分数从小到大对beam进行排序# 删除分数最小的对应的beam,然后把最小的分数更新sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])del self.beams[sorted_next_scores[0][1]]self.worst_score = sorted_next_scores[1][0]else:self.worst_score = min(score, self.worst_score)

is_done方法用来判断是否所有beam都已经完成了生成:

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:"""If there are enough hypotheses and that none of the hypotheses being generated can become better than the worstone in the heap, then we are done with this sentence."""if len(self) < self.num_beams:return False# `True`: stop as soon as at least `num_beams` hypotheses are finishedif self.early_stopping is True:return True# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate#  when `length_penalty` is positive. See the discussion below for more details.# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565elif self.early_stopping is False:highest_attainable_score = best_sum_logprobs / cur_len**self.length_penaltyret = self.worst_score >= highest_attainable_scorereturn ret# `"never"`: compute the best possible score, depending on the signal of `length_penalty`else:# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain# its max this wayif self.length_penalty > 0.0:highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)else:highest_attainable_score = best_sum_logprobs / cur_len**self.length_penaltyret = self.worst_score >= highest_attainable_scorereturn ret

5. beam_search过程

beam_searchbeam_sample分别对应了beam_gen_modebeam_sample_gen_mode两个模式的主流程,二者的区别不是很大,先来看beam_search

    def beam_search(self,input_ids: torch.LongTensor,beam_scorer: BeamScorer,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,max_length: Optional[int] = None,pad_token_id: Optional[int] = None,eos_token_id: Optional[Union[int, List[int]]] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,output_scores: Optional[bool] = None,return_dict_in_generate: Optional[bool] = None,synced_gpus: Optional[bool] = False,**model_kwargs,) -> Union[BeamSearchOutput, torch.LongTensor]:

其中这些输入参数,多数在前一篇博客中已经介绍过,这里需要注意的是BeamScorer,这个类在本文的3.2中进行了详细的介绍,它是一个用来在生成过程中,对每一个step的概率得分进行计算,并且判断生成过程是否结束。

在这个方法中,有一个while true的循环,是其主体部分,也是beam search核心逻辑的体现。在这个while之前的部分基本都是些实例化初始化的内容,理解起来没有什么困难。唯一需要额外注意的,应该是beam score的初始化问题。

5.1 beam score初始化

beam score的初始化是一个比较细节的问题,并且是新版的代码对其进行了改进。

理论上,对于beam search的过程,需要维护一个beam score来记录生成过程中每个beam的得分即可,也就是维护一个(batch_size, num_beams)的tensor,然而在代码的实现中,却有这样一个细节:

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens# of the first beam are considered to avoid sampling the exact same tokens across all beams.beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)beam_scores[:, 1:] = -1e9beam_scores = beam_scores.view((batch_size * num_beams,))

即batch中的每一条,对应的要生成的所有beam中,只有第1个beam的得分初始化为0,其余beam全部都初始化为-inf。代码的注释也对这样的用意进行了解释:防止在生成过程中,所有的beam产生的结果都是一样的。

这里举一个例子来对此进行说明。

假如有这样的场景,有这样的一个句子作为开头:“我的家在”,需要模型生成接下来的内容。那么在下一个step,需要根据现有的序列“我的家在”,来计算词表中所有词的得分。假如beam_size为2,那么就会保留了得分最高的两个,此时我们期望得到的两个beam可能分别为:beam 1:“我的家在东”
beam 2:“我的家在北”然后再一个step,这两个beam分别变成了:beam 1:“我的家在东北”
beam 2:“我的家在北京”然而实际情况却并非如此,实际上,每个beam是一个容器类BeamHypotheses,
在计算时,第一个beam的序列为“我的家在”时,第二个beam的序列也是“我的家在”,
这样一来,两个序列的tokens完全一致,对应的scores完全一致,后续生成的结果,也就会一直重复下去了。把第一个beam的分初始化为0,其余beam初始化为负无穷,就可以确保生成出来的只能在第一个beam对应的序列里,
就不会出现一直重复的情况了。

下面直接通过while循环来看beam search的主体逻辑。

5.2 准备输入

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

prepare_inputs_for_generation这个方法在GenerationMixin中没有定义,需要在具体的模型中定义,举一个最简单的例子,在BART中,该方法仅仅是将(input_ids, **model_kwargs)做了简单的整理,而没有做更多的处理:

# 代码位置:
# transformers.models.bart.modeling_bart.pydef prepare_inputs_for_generation(self,decoder_input_ids,past_key_values=None,attention_mask=None,decoder_attention_mask=None,head_mask=None,decoder_head_mask=None,cross_attn_head_mask=None,use_cache=None,encoder_outputs=None,**kwargs,):# cut decoder_input_ids if past_key_values is usedif past_key_values is not None:decoder_input_ids = decoder_input_ids[:, -1:]return {"input_ids": None,  # encoder_outputs is defined. input_ids not needed"encoder_outputs": encoder_outputs,"past_key_values": past_key_values,"decoder_input_ids": decoder_input_ids,"attention_mask": attention_mask,"decoder_attention_mask": decoder_attention_mask,"head_mask": head_mask,"decoder_head_mask": decoder_head_mask,"cross_attn_head_mask": cross_attn_head_mask,"use_cache": use_cache,  # change this to avoid caching (presumably for debugging)}

而对于某些模型,则会对模型的输入提前做一些预处理,而预处理的部分就会写在prepare_inputs_for_generation中,例如ChatGLM。

5.3 前向forward

有了输入之后,自然要将输入传输给模型进行计算,也就是网络的前向传播阶段,这里的self是调用自身,也就是GenerationMixin这个类,而我们在之前的分析中知道,其实这个类是被实际调用的模型所继承的,所以实际上这里是使用了生成模型的forward方法。

            outputs = self(**model_inputs,return_dict=True,output_attentions=output_attentions,output_hidden_states=output_hidden_states,)

所以这个outputs,就是包含了loss,logits,以及可能包含attention与past_v_k等各种信息的计算结果。

还是以BART为例,在BartForConditionalGeneration可以看到,它主要就是先经过了transformer网络,得到一个形状为(seq_len, bsz, hidden)的hidden_states,然后将其映射到词表上,就得到了在词表空间上的概率分布,形状为(bsz, seq_len, vocab),也就是常说的logits。多数ConditionalGeneration模型都是这样的一个套路。

# 代码位置:
# transformers.models.bart.modeling_bart.pydef forward(self,input_ids: torch.LongTensor = None,attention_mask: Optional[torch.Tensor] = None,decoder_input_ids: Optional[torch.LongTensor] = None,decoder_attention_mask: Optional[torch.LongTensor] = None,head_mask: Optional[torch.Tensor] = None,decoder_head_mask: Optional[torch.Tensor] = None,cross_attn_head_mask: Optional[torch.Tensor] = None,encoder_outputs: Optional[List[torch.FloatTensor]] = None,past_key_values: Optional[List[torch.FloatTensor]] = None,inputs_embeds: Optional[torch.FloatTensor] = None,decoder_inputs_embeds: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple, Seq2SeqLMOutput]:r"""labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.Returns:"""return_dict = return_dict if return_dict is not None else self.config.use_return_dictif labels is not None:if use_cache:logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")use_cache = Falseif decoder_input_ids is None and decoder_inputs_embeds is None:decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)outputs = self.model(input_ids,attention_mask=attention_mask,decoder_input_ids=decoder_input_ids,encoder_outputs=encoder_outputs,decoder_attention_mask=decoder_attention_mask,head_mask=head_mask,decoder_head_mask=decoder_head_mask,cross_attn_head_mask=cross_attn_head_mask,past_key_values=past_key_values,inputs_embeds=inputs_embeds,decoder_inputs_embeds=decoder_inputs_embeds,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)lm_logits = self.lm_head(outputs[0])lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)masked_lm_loss = Noneif labels is not None:labels = labels.to(lm_logits.device)loss_fct = CrossEntropyLoss()masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))if not return_dict:output = (lm_logits,) + outputs[1:]return ((masked_lm_loss,) + output) if masked_lm_loss is not None else outputreturn Seq2SeqLMOutput(loss=masked_lm_loss,logits=lm_logits,past_key_values=outputs.past_key_values,decoder_hidden_states=outputs.decoder_hidden_states,decoder_attentions=outputs.decoder_attentions,cross_attentions=outputs.cross_attentions,encoder_last_hidden_state=outputs.encoder_last_hidden_state,encoder_hidden_states=outputs.encoder_hidden_states,encoder_attentions=outputs.encoder_attentions,)

5.4 计算下一个step每个token的得分

在上一小节中,前向计算的结果,有很多项,其中在生成过程中,最关键的就是logits,它直接关系到下一个step生成的token是什么。

    1. logits的形状为(bsz, seq_len, vocab),所以下面代码中,第一行取的[:, -1, :],也就是取最后一个位置的概率分布,即用来生成下一个step的token。
    1. adjust_logits_during_generation是具体的模型定义的特殊方法,在生成过程中用来控制logits,如果不需要额外的控制,这个方法会默认返回logits本身。代码中的例子是在marian预训练模型中需要确保pad_token永远不被预测出,所以强行将其对应的logits设置为-inf(Marain是与BART非常类似的一个Encoder-Decoder模型,HF上最常用的翻译模型Helsinki系列就是用了这个结构)。
    1. 取log softmax(dim=-1)将logits变成vocab空间上的“概率”。
    1. 使用之前实例化的logits_processor对计算出的概率进行进一步的处理(logits_processor的介绍可以参考本文的上篇)
    1. 将processor处理之后的得分,与beam本身的得分相加算总分,即新的beam总分=原来的beam总分+即将生成的新token的分。这里可以回顾一下之前beam score初始化的细节,在while循环中的第一个step,只有第一个beam的分不是-inf,而之后的step中就不存在这个问题了。
            next_token_logits = outputs.logits[:, -1, :]# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`# cannot be generated both before and after the `nn.functional.log_softmax` operation.next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)next_token_scores_processed = logits_processor(input_ids, next_token_scores)next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

这里还有一个细节,就是在最后一步中,为什么next_token_scores_processedbeam_scores可以直接相加,我的理解是,计算最大概率的beam,其基本的概率公式应该是每一个step的概率相乘:
S c o r e c u r = p 0 ∗ p 1 ∗ . . . ∗ p i = ( p 0 ∗ p 1 ∗ . . . ∗ p i − 1 ) ∗ p i = S c o r e p r e v ∗ p i Score_{cur}=p_0*p_1*...*p_i=\left(p_0*p_1*...*p_{i-1}\right)*p_i=Score_{prev}*p_i Scorecur=p0p1...pi=(p0p1...pi1)pi=Scoreprevpi
而由于在之前的得分计算中,已经取了对数,也就把原本乘性的问题变成了加性,二者自然可以直接相加了。

5.5 选择next token

在sample之前,有一个reshape的过程,将next_token_scores的形状从(batch_size * num_beams, vocab_size)变成了(batch_size, num_beams * vocab_size),也就是说,将num_beam展平在了vocab的维度上:

reshape示意
经过了这样的reshape,就把batch中的每一条样本,其包含的所有beam,放在一起进行对比了。更具体一点来讲,就是在
[
选择第1个beam的情况下,再选词表中第1个词,
选择第1个beam的情况下,再选词表中第2个词,
…,
选择第2个beam的情况下,再选词表中第1个词,
…,
选择第2个beam的情况下,再选词表中第6个词,
]
之中,选取概率最高的。由于这里介绍的是最基础的beam_gen_mode,所以还没涉及到top_k等超参数,这些部分在下文中会继续介绍。

在实际操作中,多采了一倍的token作为备选,以确保后续不会出问题。

接下来的torch.div的操作,是因为在topk之前,将beam展平在了vocab上,所以算出来的indices是在所有beam上的一个“绝对位置”,需要将它变成在每一个beam上的“相对位置”。

            # Store scores, attentions and hidden_states when requiredif return_dict_in_generate:if output_scores:scores += (next_token_scores_processed,)if output_attentions:decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,))if self.config.is_encoder_decoder:cross_attentions += (outputs.cross_attentions,)if output_hidden_states:decoder_hidden_states += ((outputs.decoder_hidden_states,)if self.config.is_encoder_decoderelse (outputs.hidden_states,))# reshape for beam searchvocab_size = next_token_scores.shape[-1]next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")next_tokens = next_tokens % vocab_size

5.6 更新beam状态

在这一步中,对beam的状态进行了更新,依赖BeamScorer的process方法。

            # statelessbeam_outputs = beam_scorer.process(input_ids,next_token_scores,next_tokens,next_indices,pad_token_id=pad_token_id,eos_token_id=eos_token_id,beam_indices=beam_indices,)

这里就涉及到了scorer的process部分,对此进行详细的说明:

注意自此开始,代码跳转到transformers.generation.beam_search.BeamSearchScorer.process

首先,将beam的数量,也就是BeamHypotheses容器的数量作为batch_size,对输入input_ids的形状做了检验,并且对下一个step的三项基本状态beam_scoresbeam_tokensbeam_indices进行了初始化。

注意这三项既是process方法的输入,也是process最终的输出,作为下一次process的输入。

并且,在上一节的代码中可以看到,beam_scores在传入给process之前,已经在dim=1上做了排序,也就是在vocab_size的维度。

        cur_len = input_ids.shape[-1]batch_size = len(self._beam_hyps)if not (batch_size == (input_ids.shape[0] // self.group_size)):if self.num_beam_groups > 1:raise ValueError(f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "f"size of {self.group_size} is expected by the beam scorer.")else:raise ValueError(f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "f"{self.group_size} is expected by the beam scorer.")device = input_ids.devicenext_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)if isinstance(eos_token_id, int):eos_token_id = [eos_token_id]

接下来是process的主体部分,对每一个beam_hyp(即每一个进行生成中的束)进行遍历,过程的细节以注释的形式写在了代码里,这一部分的逻辑不算复杂,但是其中也涉及到了一些由分组运算而引发的细节问题:

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):# 如果当前这一束已经被标记为完成了生成,则将三项输出结果进行paddingif self._done[batch_idx]:if self.num_beams < len(beam_hyp):raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")if eos_token_id is None or pad_token_id is None:raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")# pad the batchnext_beam_scores[batch_idx, :] = 0next_beam_tokens[batch_idx, :] = pad_token_idnext_beam_indices[batch_idx, :] = 0continue# next tokens for this sentence# 如果当前这一束还没有完成,则计算这一束的下一个tokenbeam_idx = 0for beam_token_rank, (next_token, next_score, next_index) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])):# 由于在一开始输入到process时,next_tokens等,就是在vocab维度上排好序的# 所以这里只需要按顺序添加即可# 这里是将某个beam中的相对位置恢复为整个tensor中的绝对位置,注意看第5.5节中的图batch_beam_idx = batch_idx * self.group_size + next_index# add to generated hypotheses if end of sentence# 最高得分是结束符的情况if (eos_token_id is not None) and (next_token.item() in eos_token_id):# if beam_token does not belong to top num_beams tokens, it should not be added# 如果当前得分最高的是结束符则需要进行额外的一步判断# 因为在计算得分的时候是将一组中所有beam放在一起计算的,所以即便是预测到了eos,# 如果它不再前第num_beams个token范围内的话,那这个eos就不能算数is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_sizeif is_beam_token_worse_than_top_num_beams:continueif beam_indices is not None:beam_index = beam_indices[batch_beam_idx]beam_index = beam_index + (batch_beam_idx,)else:beam_index = Nonebeam_hyp.add(input_ids[batch_beam_idx].clone(),next_score.item(),beam_indices=beam_index,)else:# add next predicted token since it is not eos_token# 如果不是eos token的话,则直接添加即可next_beam_scores[batch_idx, beam_idx] = next_scorenext_beam_tokens[batch_idx, beam_idx] = next_tokennext_beam_indices[batch_idx, beam_idx] = batch_beam_idxbeam_idx += 1# once the beam for next step is full, don't add more tokens to it.if beam_idx == self.group_size:breakif beam_idx < self.group_size:raise ValueError(f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected.")# Check if we are done so that we can save a pad step if all(done)# 更新beam的完成状态cur_len += 1  # add up to the length which the next_scores is calculated onself._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(next_scores[batch_idx].max().item(), cur_len)return UserDict({"next_beam_scores": next_beam_scores.view(-1),"next_beam_tokens": next_beam_tokens.view(-1),"next_beam_indices": next_beam_indices.view(-1),})

以上就是BeamScorer的process过程,在计算出新的beam_scores等三项结果之后,还需要进行进一步的处理:

注意从这里开始,代码回到transformers.generation.utils.GenerationMixin.beam_search

这一部分代码用来更新生成参数,保存past_key_values,以及判断是否满足停止条件。

            beam_scores = beam_outputs["next_beam_scores"]beam_next_tokens = beam_outputs["next_beam_tokens"]beam_idx = beam_outputs["next_beam_indices"]input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)if model_kwargs["past_key_values"] is not None:model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)if return_dict_in_generate and output_scores:beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))# increase cur_lencur_len = cur_len + 1if beam_scorer.is_done or stopping_criteria(input_ids, scores):if not synced_gpus:breakelse:this_peer_finished = True

5.7 后处理finalize

当生成终止后,还需要进行一个统一的后处理流程,以选择最佳的序列作为最终结果返回。

代码位于transformers.generation.beam_search.BeamSearchScorer.finalize

在这个环节中,首先需要把没有完成的beam对应的token和score添加到容器中。回顾process部分的代码,可以看到,只有当预测出eos token,并且满足一定条件时,token和score才会被添加到beam_hyp容器中,而根据beam search的整体逻辑,每个step的状态更新完成时,不管是否添加到了容器中,都需要对结束状态进行判断,而判断时,stopping_criteria就会发挥作用了。这也就会造成存在这样一种情况,还没有结束生成的beam,由于满足了stopping_criteria的中止条件,而提前中止,此时的token和score并没有被添加到beam_hyp中,所以需要这样一个后处理的动作,来确保最终得到的beam数量,等于预先设置的num_beams。

        batch_size = len(self._beam_hyps)if isinstance(eos_token_id, int):eos_token_id = [eos_token_id]# finalize all open beam hypotheses and add to generated hypothesesfor batch_idx, beam_hyp in enumerate(self._beam_hyps):if self._done[batch_idx]:continue# all open beam hypotheses are added to the beam hypothesis# beam hypothesis class automatically keeps the best beamsfor beam_id in range(self.num_beams):batch_beam_idx = batch_idx * self.num_beams + beam_idfinal_score = final_beam_scores[batch_beam_idx].item()final_tokens = input_ids[batch_beam_idx]beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else Nonebeam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

然后根据score从高到低对所有的束进行排序,保留得分最高的num_beam_hyps_to_keep个束。

        # select the best hypothesessent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)best = []best_indices = []best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)# retrieve best hypothesesfor i, beam_hyp in enumerate(self._beam_hyps):sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])for j in range(self.num_beam_hyps_to_keep):best_hyp_tuple = sorted_hyps.pop()best_score = best_hyp_tuple[0]best_hyp = best_hyp_tuple[1]best_index = best_hyp_tuple[2]sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)# append hyp to listsbest.append(best_hyp)# append indices to listbest_indices.append(best_index)best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

最后,对保留的所有束进行padding,已经添加eos结束符:

        # prepare for adding eossent_lengths_max = sent_lengths.max().item() + 1sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_maxdecoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)if len(best_indices) > 0 and best_indices[0] is not None:indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)else:indices = None# shorter batches are padded if neededif sent_lengths.min().item() != sent_lengths.max().item():assert pad_token_id is not None, "`pad_token_id` has to be defined"decoded.fill_(pad_token_id)if indices is not None:indices.fill_(-1)# fill with hypotheses and eos_token_id if the latter fits infor i, (hypo, best_idx) in enumerate(zip(best, best_indices)):decoded[i, : sent_lengths[i]] = hypoif indices is not None:indices[i, : len(best_idx)] = torch.tensor(best_idx)if sent_lengths[i] < sent_max_len:# inserting only the first eos_token_iddecoded[i, sent_lengths[i]] = eos_token_id[0]return UserDict({"sequences": decoded,"sequence_scores": best_scores,"beam_indices": indices,})

以上就是beam search的完整流程了。在实际应用中,使用更多的方法一般是beam search的升级版,beam sample,在第6节中,将简单介绍一下beam sample模式与一般beam search的主要区别。

6. beam sample

Beam sample与一般的beam search相比,主要区别体现在其需要根据GenerationConfig的配置,创建若干logits warper,对计算出的next_token_scores进行进一步的加工。

从代码中来看,beam_sample方法与beam_search方法相比,区别主要在于while True的循环中,增加了logits_warper

            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)next_token_scores_processed = logits_processor(input_ids, next_token_scores)next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867# 下边这一行是新增的:next_token_scores = logits_warper(input_ids, next_token_scores)# Store scores, attentions and hidden_states when requiredif return_dict_in_generate:if output_scores:# beam_search中原本是这样的:# scores += (next_token_scores_processed,)# 下边这一行是beam_sample的:scores += (logits_warper(input_ids, next_token_scores_processed),)if output_attentions:decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,))if self.config.is_encoder_decoder:cross_attentions += (outputs.cross_attentions,)if output_hidden_states:decoder_hidden_states += ((outputs.decoder_hidden_states,)if self.config.is_encoder_decoderelse (outputs.hidden_states,))

在本文的上篇的4.11节中,对创建的logits warper进行了简单的介绍。这里就以其中一种logits wrapper为例进行介绍。

Temperature是生成过程中一项重要超参数,它控制着生成结果是否具有“创造性”。这个数值一般介于 [ 0.1 , 1 ] [0.1, 1] [0.1,1],该值越大,越倾向于生成概率不那么高的token,结果更具有“创造性”,等于1时,相当于原始的softmax得到的分布;而该值越小,则倾向于生成更加保守的结果,当接近于0时,则趋向于greedy search。

对应的wrapper实现如下:

class TemperatureLogitsWarper(LogitsWarper):r"""[`LogitsWarper`] for temperature (exponential scaling output probability distribution).Args:temperature (`float`):The value used to module the logits distribution."""def __init__(self, temperature: float):if not isinstance(temperature, float) or not (temperature > 0):raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")self.temperature = temperaturedef __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:scores = scores / self.temperaturereturn scores

从中可以看到,它只是将原来的得分除以temperature的数值。结合logits_warper在整体流程中的位置(warper的调用位于softmax之后),可以看出这一计算并没有在当前step生效,而是在下一个step时才会生效,这也符合带temperature的softmax的公式。
原始的softmax:
s ( x i ) = exp ⁡ x i ∑ j = 0 N exp ⁡ x j s(x_i) = \frac{\exp^{x_i}} {\sum_{j=0}^N \exp^{x_j}} s(xi)=j=0Nexpxjexpxi

增加temperature之后的softmax:
s ( x i ) = exp ⁡ x i t ∑ j = 0 N exp ⁡ x j t s(x_i) = \frac{\exp^{\frac{x_i} {t}}} {\sum_{j=0}^N \exp^{\frac {x_j} {t}}} s(xi)=j=0Nexptxjexptxi

其他的warper也是类似的使用方法,是作用在softmax计算完当前step的得分之后。

7. 总结

至此,transformers模块中generate相关的使用方法就已经全部介绍清楚了,随着代码的更新升级,其中的实现细节或许会发生些许变化,但只要NLG的大框架不被推翻,生成的基本逻辑就不会发生什么大的变化。在LLM迅速发展的当下,对于多数研究人员而言,或许并没有条件从头训练一个自己专属的模型,于是,如何利用好logits processor和stopping criteria,在已有模型的基础上灵活的进行生成,从代码实现的角度,理解模型是如何生成一个完整的序列,就格外重要了。

本文的写作花费了比较大的精力,期间由于个人原因搁置了一段时间,回过头来继续编写时,发现transformers源码已经发生了较大更新,无奈只好将代码部分重写。我的博客会持续不定期更新,分享近期热门人工智能相关知识技术,以及学习和实验过程中积累的体会心得,更新频率取决于我的业余时间是否充裕。写作纯属个人兴趣,没有任何收益来源,如果本文对你的学习或工作带来了帮助,麻烦留下一个免费的赞,大家的支持就是我更新的动力。

欢迎留言讨论,如需转载,请注明出处。

这篇关于以beam search为例,详解transformers中generate方法(下)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/336524

相关文章

Window Server2016加入AD域的方法步骤

《WindowServer2016加入AD域的方法步骤》:本文主要介绍WindowServer2016加入AD域的方法步骤,包括配置DNS、检测ping通、更改计算机域、输入账号密码、重启服务... 目录一、 准备条件二、配置ServerB加入ServerA的AD域(test.ly)三、查看加入AD域后的变

Window Server2016 AD域的创建的方法步骤

《WindowServer2016AD域的创建的方法步骤》本文主要介绍了WindowServer2016AD域的创建的方法步骤,文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、准备条件二、在ServerA服务器中常见AD域管理器:三、创建AD域,域地址为“test.ly”

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

JAVA系统中Spring Boot应用程序的配置文件application.yml使用详解

《JAVA系统中SpringBoot应用程序的配置文件application.yml使用详解》:本文主要介绍JAVA系统中SpringBoot应用程序的配置文件application.yml的... 目录文件路径文件内容解释1. Server 配置2. Spring 配置3. Logging 配置4. Ma

Java 字符数组转字符串的常用方法

《Java字符数组转字符串的常用方法》文章总结了在Java中将字符数组转换为字符串的几种常用方法,包括使用String构造函数、String.valueOf()方法、StringBuilder以及A... 目录1. 使用String构造函数1.1 基本转换方法1.2 注意事项2. 使用String.valu

mac中资源库在哪? macOS资源库文件夹详解

《mac中资源库在哪?macOS资源库文件夹详解》经常使用Mac电脑的用户会发现,找不到Mac电脑的资源库,我们怎么打开资源库并使用呢?下面我们就来看看macOS资源库文件夹详解... 在 MACOS 系统中,「资源库」文件夹是用来存放操作系统和 App 设置的核心位置。虽然平时我们很少直接跟它打交道,但了

Python中使用defaultdict和Counter的方法

《Python中使用defaultdict和Counter的方法》本文深入探讨了Python中的两个强大工具——defaultdict和Counter,并详细介绍了它们的工作原理、应用场景以及在实际编... 目录引言defaultdict的深入应用什么是defaultdictdefaultdict的工作原理

关于Maven中pom.xml文件配置详解

《关于Maven中pom.xml文件配置详解》pom.xml是Maven项目的核心配置文件,它描述了项目的结构、依赖关系、构建配置等信息,通过合理配置pom.xml,可以提高项目的可维护性和构建效率... 目录1. POM文件的基本结构1.1 项目基本信息2. 项目属性2.1 引用属性3. 项目依赖4. 构

Rust 数据类型详解

《Rust数据类型详解》本文介绍了Rust编程语言中的标量类型和复合类型,标量类型包括整数、浮点数、布尔和字符,而复合类型则包括元组和数组,标量类型用于表示单个值,具有不同的表示和范围,本文介绍的非... 目录一、标量类型(Scalar Types)1. 整数类型(Integer Types)1.1 整数字

使用Python进行文件读写操作的基本方法

《使用Python进行文件读写操作的基本方法》今天的内容来介绍Python中进行文件读写操作的方法,这在学习Python时是必不可少的技术点,希望可以帮助到正在学习python的小伙伴,以下是Pyth... 目录一、文件读取:二、文件写入:三、文件追加:四、文件读写的二进制模式:五、使用 json 模块读写