以beam search为例,详解transformers中generate方法(上)

比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate(),本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,在本人能力范围内,尽可能详细地展开介绍。

考虑到篇幅可能会比较长,本文将分为上下两篇,上篇主要介绍generate方法的整体结构,下篇将对beam search部分的代码进行进一步的介绍。

随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:

    1. generate的源码位置发生了改变;
    1. generate方法中,采用一个generation_config参数来管理生成相关的各种配置,并优化了逻辑,使得逻辑更加清晰。

1. generate的代码位置

在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py,这个方法是GenerationMixin类的一个方法。

而在新版本的transformers中(transformers~=4.28),generate方法被转移到了transformers.generation.utils.py,仍然是GenerationMixin的一个类方法。

而对于一个hf形式的预训练模型,都是继承了PreTrainedModel类的,而顺着这个PreTrainedModel类,可以看到更上一级的继承逻辑,GenerationMixin就在其中:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

这就是为什么通过AutoModel.from_pretrained()实例化的一个model为什么可以直接调用generate方法去做推理。

2. GenerationMixin概览

这一部分作为一个速查表写在这里,不建议直接阅读,而是在读第4节代码的过程中,返回来查看这部分内容。

GenerationMixin类所有方法概览如下:

方法名 作用 在本文中出现的位置
_validate_model_class 检修该模型是否可以做生成,并抛出相应的异常 4.1
_validate_model_kwargs 检查generation config中的参数是否与生成策略相匹配 4.1
_prepare_model_inputs 为生成过程准备输入 4.3
_maybe_initialize_input_ids_for_generation 当生成过程的inputs为空时,使用bos token做初始化 4.3
_prepare_attention_mask_for_generation 为生成过程准备attention_mask 4.4
_prepare_encoder_decoder_kwargs_for_generation 为生成过程准备encoder相关的参数 4.4
_prepare_decoder_input_ids_for_generation 为自回归模型额外处理input_ids 4.5
_get_decoder_start_token_id 获取decoder的开始位置的token id,这个id可能与bos不同 4.5
_get_logits_processor 创建logits处理器 4.8
_get_stopping_criteria 创建停止规则 4.9
_get_logits_warper 创建logits warper 4.11
_expand_inputs_for_generation 根据num_beams对input_ids进行扩展 4.12
prepare_inputs_for_generation 对模型的输入进行预处理 下篇3.1
adjust_logits_during_generation 在生成过程中对计算的logits进行调整 下篇3.1
_update_model_kwargs_for_generation 根据一个step的生成结果,更新生成参数 下篇5.6
_reorder_cache 根据step更新的beam_idx,对缓存的past_k_v进行重排 下篇5.6

3. generate签名

在介绍流程之前先看一下generate方法的签名,在4.28版本中,其签名简化如下:

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:

相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k, top_p, num_beams等参数全部都整合到了generation_config中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷。

相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。
需要在GenerationConfig中查看可选的参数:

from transformers.generation.configuration_utils import GenerationConfig

help(GenerationConfig)

GenerationConfig中各类生成策略对应的参数各有不同,这里不展开介绍,在本文的下篇中,将对beam search策略下的参数进行简介。)


generate方法的参数含义与作用介绍如下:

参数名 类型 含义与作用
inputs torch.Tensor tokenize之后的序列id,模型将基于这个序列计算logits并进行生成。如果为空,则默认为bos token对应的id
generation_config GenerationConfig 各种生成策略对应的参数,如果为空,将会从模型路径下的generation_config.json文件中读取,或从model config获取
logits_processor LogitsProcessorList 对模型计算出的logits进行进一步处理,例如对“复读机现象”相应的概率进行惩罚,以避免模型生成结果不断重复
stopping_criteria StoppingCriteriaList 对生成过程做停止控制的工具,例如达到一定长度时强行停止,达到一定生成时间时停止等
prefix_allowed_tokens_fn [int, torch.Tensor], List[int] beam search过程中,每个step允许生成的token id范围
synced_gpus bool 采用DeepSpeed ZeRO时使用
streamer BaseStreamer stream generate时使用(也就是一个字一个字的往外蹦的效果)

在这些输入中,logits_processor和stopping_criteria,将是用户手动干预生成过程的主要手段。

4. generate过程

在4.28版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。

4.1 读取并更新generation config

这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation -- update the generation config
            # model attribute accordingly, if it was created from the model config
            if self.generation_config._from_model_config:
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use a generation configuration file (see"
                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

其中_validate_model_class_validate_model_kwargs两个方法都不是重点,这里不展开介绍。

4.2 补充没有传入的参数

这部分需要补充的参数包括logits_processor, stopping_criteria, 以及generation_config中的pad_token_id。前两项是设置为默认的空list;pad_token_id没有给定,而eos给定的话,用eos来做padding。

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{
      
      eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

4.3 定义模型输入

        # 3. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

这里主要需要关注_prepare_model_inputs这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况:

def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {
    
    k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
        
        # 确保inputs没有重复传入
        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {
      
      inputs}` were passed alongside {
      
      input_name} which is not allowed."
                f"Make sure to either pass {
      
      inputs} or {
      
      input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg
		
		# 对于inputs_embeds这一输入参数:
		# 如果是decoder-only模型,需要把'input_ids'这一参数放在inputs_kwarg中传入
		# 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一
        # 3. In the presence of `inputs_embeds` for text models:
        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
            if not self.config.is_encoder_decoder:
                has_inputs_embeds_forwarding = "inputs_embeds" in set(
                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
                )
                if not has_inputs_embeds_forwarding:
                    raise ValueError(
                        f"You passed `inputs_embeds` to `.generate()`, but the model class {
      
      self.__class__.__name__} "
                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                    )
                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
                # the attention mask) can rely on the actual model input.
                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                    inputs, bos_token_id, model_kwargs=model_kwargs
                )
            else:
                if inputs is not None:
                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        # 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:
        # torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

4.4 定义模型的其他参数

这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。

        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        model_kwargs["use_cache"] = generation_config.use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
            )

        # decoder-only models should use left-padding for generation
        if not self.config.is_encoder_decoder:
            if (
                generation_config.pad_token_id is not None
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

4.5 对自回归模型准备input_ids

这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,则直接采用4.3创建的input_tensor作为input_ids。

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
        	# 这里主要是针对decoder的开始位置id与bos id不同的情况
        	# 在这种情况下,decoder-only模型应当以配置中规定的decoder start id开始进行生成
        	# 此处可简单理解为:torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )

            # conditional generation for multi-modal models.
            if "input_ids" in model_kwargs and model_input_name == "pixel_values":
                input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

4.6 准备最大长度

这一部分就是根据config中的相关配置,判断input_id的长度有没有超长。

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({
      
      generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={
      
      generation_config.max_new_tokens}) and `max_length`(="
                    f"{
      
      generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({
      
      generation_config.min_length}) is larger than"
                f" the maximum length ({
      
      generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {
      
      input_ids_string} is {
      
      input_ids_seq_length}, but `max_length` is set to"
                f" {
      
      generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

4.7 确认生成模式

这里直接选择beam search分支了,其他模式不做展开介绍,下同。

beam search分为两种,基础款的beam_gen_mode,以及进阶款的beam_sample_gen_mode,其中,前者对应后续的生成方法为beam_search,后者对应后续的生成方法为beam_sample

二者的区别主要在于,进阶款的beam_sample_gen_mode可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在4.11节:logits warper中介绍。对于基础款的beam_gen_mode,就没有创建logits warper这一环节。

        # 7. determine generation mode
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

4.8 创建logits处理器

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。

在beam search中,logits process的使用方法是:

# 在def beam_sample中使用
next_token_scores_processed = logits_processor(input_ids, next_token_scores)

其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。

在这里简单介绍一下在transformers中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。

processor 作用 参考连接
MinLengthLogitsProcessor 通过将EOS的概率强行设置为0,保证生成结果的长度大于等于一个最小值 /
MinNewTokensLengthLogitsProcessor 与上一个类似,但是prompt的部分不计入生成长度 /
RepetitionPenaltyLogitsProcessor 防止“复读机”现象,给重复出现token添加惩罚,由预训练模型CTRL提出 arxiv
EncoderRepetitionPenaltyLogitsProcessor 与上一个区别在于,生成的结果不能与encoder输入input id重复,而非与当前给定的全部input id /
NoRepeatNGramLogitsProcessor 防止生成的文本中出现重复的n-gram(n个连续的词或字符),区别在于限制连续n个 github
EncoderNoRepeatNGramLogitsProcessor n-gram可以在encoder的input ids中重复,不可以在decoder重复 github
NoBadWordsLogitsProcessor 确保某些词永远不会被生成 /
PrefixConstrainedLogitsProcessor 给定一个prefix_allow_func来限制符合哪些条件的token可以被生成 arxiv
HammingDiversityLogitsProcessor 以Hamming距离为标准,确保生成的各个beam之前的区别足够大 arxiv
ForcedBOSTokenLogitsProcessor 确保生成的第一个token是某个特定的token /
ForcedEOSTokenLogitsProcessor 达到最大长度时,确保以某个特定的token作为结束 /
InfNanRemoveLogitsProcessor 将计算出的得分中,nan替换为0,inf替换为可计算的最大值 /
SuppressTokensAtBeginLogitsProcessor 在达到某个长度之后,将不再生成某些特定的词 /
SuppressTokensLogitsProcessor 将某些特定词的概率设置为-inf,不生成这些词 /
ForceTokensLogitsProcessor 建立一个映射表,把某个token强行映射成另一个token /
WhisperTimeStampLogitsProcessor 强制模型生成时间戳(时间戳是一个特殊token,例如对话中,query=今天是周几?,answer=今天是[timestamp],这个[timestamp]后期会替换成对应的时间) /

4.9 创建停止规则

stopping_criteria与logits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。

        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )

预设的criteria总结如下:

criteria 作用
MaxLengthCriteria 生成的序列达到设置的最大长度时,停止生成
MaxNewTokensCriteria 生成的序列中,除去prompt的部分达到设置的最大长度时,停止生成
MaxTimeCriteria 生成的耗时超过一定时间限制时,停止生成

如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria

4.10 进入相应的分支

这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。

4.11 创建logits warper

            # 11. prepare logits warper
            logits_warper = self._get_logits_warper(generation_config)

logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。

普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。

需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)

预设的warper包括:

warper 作用(仅供参考) 参考链接
TemperatureLogitsWarper 对score整体除以temperature做缩放 /
TopPLogitsWarper 概率小于topp的得分置为0 /
TopKLogitsWarper 只取topk的概率对应的词汇,其余的概率置为-inf /
TypicalLogitsWarper typical decoding arxiv
EpsilonLogitsWarper 将概率小于epsilon的token移除 arxiv
EtaLogitsWarper eta-sampling arxiv
LogitNormalization 在beam search进行的过程中做layernorm /

4.12 beam search

这一部分是beam search的核心流程,限于篇幅,其具体的执行生成过程将在本文的下篇中进行详细的介绍。

在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在下篇中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search,或beam sample对应的beam_sample方法。

            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

在本文的下篇中,将介绍beam search的基本原理,transformers模块对于beam search的实现方法中,主要涉及的几个工具组件,beam search的生成与更新过程,以及beam sample对beam search的改进实现,感兴趣的朋友可以继续阅读。

猜你喜欢

转载自blog.csdn.net/weixin_44826203/article/details/129928897
今日推荐