Taking beam search as an example: a detailed explanation of the generate method in transformers (Part 1)

Compared with two years ago, NLG has developed rapidly, and the transformers module is used more widely than ever. In model inference and prediction, a core statement is model.generate(); this article introduces in detail how the generate method works. The generation process supports many generation strategies. This article takes the most commonly used one, beam search, as an example and introduces it in as much detail as I can.

Since the full walkthrough is fairly long, this article is split into two parts. The first part introduces the overall structure of the generate method, and the second part walks through the code of the beam search part.

With the emergence of various LLMs, some changes have taken place in the code related to generate in transformers. The main differences are:

    1. The source code location of generate has changed;
    2. The generate method now uses a generation_config parameter to manage the various generation-related settings, and the surrounding logic has been reorganized to be clearer.

1. The code location of generate

In previous versions of transformers (transformers~=4.9), the generate method is located in transformers/generation_utils.py, where it is a method of the GenerationMixin class.

In the newer version of transformers (transformers~=4.28), the generate method has moved to transformers/generation/utils.py, and it is still a method of the GenerationMixin class.

A pretrained model in Hugging Face format inherits from the PreTrainedModel class, and following PreTrainedModel up the inheritance chain you can see that GenerationMixin is one of its bases:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

This is why a model instantiated with AutoModel.from_pretrained() can call the generate method directly to run inference.
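As a quick illustration of that inheritance, here is a minimal sketch. The "gpt2" checkpoint name is only an example, and AutoModelForCausalLM is used because generation needs a model with an LM head; any generative checkpoint works the same way:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # "gpt2" is only an example checkpoint; substitute any generative model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The generate method in transformers", return_tensors="pt")
    # generate() comes from GenerationMixin via PreTrainedModel
    output_ids = model.generate(**inputs, max_new_tokens=20, num_beams=4)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))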

2. Overview of GenerationMixin

This part is meant as a cheat sheet. Rather than reading it straight through, it is better to refer back to it while working through the code in Section 4.

An overview of the relevant methods of the GenerationMixin class is as follows:

| method name | effect | where it appears in this article |
| --- | --- | --- |
| _validate_model_class | Checks whether the model can generate, and throws the corresponding exception | 4.1 |
| _validate_model_kwargs | Checks that the parameters in the generation config match the generation strategy | 4.1 |
| _prepare_model_inputs | Prepares the inputs for the generation process | 4.3 |
| _maybe_initialize_input_ids_for_generation | When the inputs of the generation process are empty, initializes them with the bos token | 4.3 |
| _prepare_attention_mask_for_generation | Prepares the attention_mask for the generation process | 4.4 |
| _prepare_encoder_decoder_kwargs_for_generation | Prepares encoder-related parameters for the generation process | 4.4 |
| _prepare_decoder_input_ids_for_generation | Additional handling of input_ids for autoregressive models | 4.5 |
| _get_decoder_start_token_id | Gets the token id of the decoder's starting position; this id may differ from the bos id | 4.5 |
| _get_logits_processor | Creates the logits processors | 4.8 |
| _get_stopping_criteria | Creates the stopping criteria | 4.9 |
| _get_logits_warper | Creates the logits warpers | 4.11 |
| _expand_inputs_for_generation | Expands input_ids according to num_beams | 4.12 |
| prepare_inputs_for_generation | Preprocesses the input to the model | Part 2, 3.1 |
| adjust_logits_during_generation | Adjusts the computed logits during generation | Part 2, 3.1 |
| _update_model_kwargs_for_generation | Updates the generation parameters according to the result of a step | Part 2, 5.6 |
| _reorder_cache | Rearranges the cached past_key_values according to the beam_idx updated at each step | Part 2, 5.6 |

3. generate signature

Before introducing the process, let's take a look at the signature of the generate method. In version 4.28, the signature is simplified as follows:

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:

Compared with the previous version, the direct advantage of writing it this way is that the originally very long signature is trimmed: parameters such as top_k, top_p, and num_beams are all folded into generation_config, which makes the function look much cleaner. These parameters can also be read directly from the generation_config.json file under the model path, which is convenient for users to a certain extent.

The corresponding disadvantage is that many parameters are no longer explicitly exposed, which is less convenient when reading the docstrings or customizing the configuration. The optional parameters need to be looked up in GenerationConfig:

from transformers.generation.configuration_utils import GenerationConfig

help(GenerationConfig)

(The parameters of GenerationConfig differ across generation strategies, so they are not all introduced here. The next part of this article briefly introduces the parameters used by the beam search strategy.)
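For illustration, a GenerationConfig can also be loaded or built explicitly and passed to generate. This is only a sketch: the "gpt2" path and the field values are arbitrary examples, and model/inputs refer to the earlier sketch:

    from transformers import GenerationConfig

    # Option 1: load the config shipped with the checkpoint
    gen_config = GenerationConfig.from_pretrained("gpt2")

    # Option 2: build one by hand for beam search (values here are arbitrary)
    gen_config = GenerationConfig(
        num_beams=4,
        max_new_tokens=64,
        length_penalty=1.0,
        early_stopping=True,
    )

    # output_ids = model.generate(**inputs, generation_config=gen_config)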


The meaning and function of the parameters of the generate method are introduced as follows:

| parameter name | type | meaning and function |
| --- | --- | --- |
| inputs | torch.Tensor | The tokenized sequence ids; the model computes logits and generates based on this sequence. If empty, it defaults to the id of the bos token |
| generation_config | GenerationConfig | The parameters for the various generation strategies. If empty, it is read from the generation_config.json file under the model path, or otherwise obtained from the model config |
| logits_processor | LogitsProcessorList | Further processes the logits computed by the model, for example penalizing the probabilities behind the "repeater" phenomenon so the model does not produce repetitive results |
| stopping_criteria | StoppingCriteriaList | Tools for stopping the generation process, for example forcing a stop when a certain length or a certain generation time is reached |
| prefix_allowed_tokens_fn | Callable[[int, torch.Tensor], List[int]] | During beam search, restricts the range of token ids allowed to be generated at each step |
| synced_gpus | bool | Used when running with DeepSpeed ZeRO |
| streamer | BaseStreamer | Used for streaming generation (i.e., tokens appearing one by one) |

Among these inputs, logits_processor and stopping_criteria are the main means for users to manually intervene in the generation process.

4. generate process

In the transformers 4.28 code, the comments in the generate method lay the process out clearly, so this article follows the numbering used in those code comments.

4.1 Read and update generation config

The general logic of this part is to deal with the case where the generation config is None, and to check whether there are any wrong parameters inconsistent with the generation strategy.

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation -- update the generation config
            # model attribute accordingly, if it was created from the model config
            if self.generation_config._from_model_config:
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use a generation configuration file (see"
                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

Neither _validate_model_class nor _validate_model_kwargs is a key point here, so they are not introduced further.

4.2 Supplement the parameters that are not passed in

The parameters filled in here are logits_processor, stopping_criteria, and pad_token_id in generation_config. The first two default to an empty list; if pad_token_id is not given but an eos token is, the eos token is used for padding.

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{
      
      eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

4.3 Defining model inputs

        # 3. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

Here we mainly need to pay attention to the _prepare_model_inputs method. Its core, in one sentence, is that the model's input sequence input_ids must be non-empty; if it is empty, it is initialized with the bos token. The rest handles special cases for individual models:

def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # This step seems to act as a sanity check: some encoder-decoder models use different input names for the main model and the encoder
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
        
        # make sure inputs is not passed in twice
        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {
      
      inputs}` were passed alongside {
      
      input_name} which is not allowed."
                f"Make sure to either pass {
      
      inputs} or {
      
      input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg
		
        # For the inputs_embeds argument:
        # - for a decoder-only model, the 'input_ids' argument is kept in model_kwargs and passed along that way
        # - for an encoder-decoder model, only one of input_ids and inputs_embeds may be passed
        # 3. In the presence of `inputs_embeds` for text models:
        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
            if not self.config.is_encoder_decoder:
                has_inputs_embeds_forwarding = "inputs_embeds" in set(
                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
                )
                if not has_inputs_embeds_forwarding:
                    raise ValueError(
                        f"You passed `inputs_embeds` to `.generate()`, but the model class {
      
      self.__class__.__name__} "
                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                    )
                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
                # the attention mask) can rely on the actual model input.
                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                    inputs, bos_token_id, model_kwargs=model_kwargs
                )
            else:
                if inputs is not None:
                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        # If there is still no input_ids at this point, create it from the bos token; this can be simplified as:
        # torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

4.4 Defining other parameters of the model

This part does not need special attention; it mainly uses the config to fill in the model's remaining parameters, such as creating the attention_mask and, for encoder-decoder models, running the encoder in advance so that its ModelOutput is available, and so on.

        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        model_kwargs["use_cache"] = generation_config.use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
            )

        # decoder-only models should use left-padding for generation
        if not self.config.is_encoder_decoder:
            if (
                generation_config.pad_token_id is not None
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )
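The warning in the snippet above concerns decoder-only models being fed right-padded batches. A sketch of the recommended tokenizer setup (padding_side and pad_token are standard tokenizer attributes; "gpt2" is again only an example checkpoint):

    from transformers import AutoTokenizer

    # Left padding keeps the last position of every row a real token,
    # which is what an autoregressive model continues from.
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    batch = tokenizer(["a short prompt", "a noticeably longer prompt"],
                      padding=True, return_tensors="pt")
    # batch["attention_mask"] marks the padded positions so generate() can ignore them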

4.5 Prepare input_ids for autoregressive models

The main difference from 4.3 is the additional handling for encoder-decoder models, which need separate decoder input ids to start autoregressive generation. For a decoder-only model, the inputs_tensor created in 4.3 is used directly as input_ids.

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            # This mainly handles the case where the decoder's start position id differs from the bos id
            # In that case the decoder should start generating from the decoder start id specified in the config
            # Roughly equivalent to: torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )

            # conditional generation for multi-modal models.
            if "input_ids" in model_kwargs and model_input_name == "pixel_values":
                input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

4.6 Prepare the maximum length

This part determines the maximum generation length from the relevant configuration and checks whether input_ids is already too long.

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({
      
      generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={
      
      generation_config.max_new_tokens}) and `max_length`(="
                    f"{
      
      generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({
      
      generation_config.min_length}) is larger than"
                f" the maximum length ({
      
      generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {
      
      input_ids_string} is {
      
      input_ids_seq_length}, but `max_length` is set to"
                f" {
      
      generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )
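As the snippet shows, when max_new_tokens is set, max_length is recomputed as max_new_tokens plus the prompt length, and max_new_tokens takes precedence if both are given. A small sketch of the arithmetic, reusing the model and inputs from the earlier sketch and assuming a prompt of 10 tokens:

    # prompt of 10 tokens:
    # generate(..., max_new_tokens=50)  -> generation_config.max_length becomes 10 + 50 = 60
    # generate(..., max_length=60)      -> at most 60 tokens in total, i.e. at most 50 new ones
    output_ids = model.generate(**inputs, max_new_tokens=50)  # the recommended style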

4.7 Confirm Generation Mode

The beam search branch is taken directly here; the other modes are not introduced (the same applies in the rest of the article).

There are two kinds of beam search: the basic beam_gen_mode and the sampling variant beam_sample_gen_mode. The former corresponds to the beam_search generation method later on, and the latter to beam_sample.

The main difference between the two is that beam_sample_gen_mode can use parameters such as temperature, top_k, and top_p to further control generation; how they are applied is introduced in section 4.11 on the logits warpers. The basic beam_gen_mode has no logits-warper step.

        # 7. determine generation mode
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
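The full source checks several such booleans. Simplified, and ignoring group beam search, constrained decoding, and contrastive search, the combinations map to generation modes roughly as follows; this is a conceptual sketch, not the exact source code:

    # Simplified mapping from generation_config fields to generation modes (sketch only)
    num_beams = generation_config.num_beams
    do_sample = generation_config.do_sample

    if num_beams == 1 and not do_sample:
        mode = "greedy_search"
    elif num_beams == 1 and do_sample:
        mode = "sample"
    elif num_beams > 1 and not do_sample:
        mode = "beam_search"   # is_beam_gen_mode
    else:
        mode = "beam_sample"   # is_beam_sample_gen_mode, as in the snippet above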

4.8 Create a logits processor

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

This step is more important because it involves the logits processors. These processors adjust the computed scores at each step during generation. transformers presets several processors, and users can also define their own (by inheriting the abstract class transformers.generation.logits_process.LogitsProcessor) and design their own logic to manually intervene in the generation process.

In beam search, the logits processors are used as follows:

# as used in beam_search / beam_sample
next_token_scores_processed = logits_processor(input_ids, next_token_scores)

Here, input_ids is the token id sequence passed to the model at the current step, a Tensor of shape (batch_size * num_beams, sequence_length), and next_token_scores is the score computed by the model (i.e., the distribution over the vocabulary) after log_softmax.
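As an illustration of this interface, a minimal custom processor might look like the sketch below. The class name and the "banned token" logic are invented for the example; only the __call__(input_ids, scores) signature comes from LogitsProcessor:

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class BanTokenLogitsProcessor(LogitsProcessor):
        """Illustrative example: never generate one particular token id."""
        def __init__(self, banned_token_id: int):
            self.banned_token_id = banned_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # scores has shape (batch_size * num_beams, vocab_size)
            scores[:, self.banned_token_id] = -float("inf")
            return scores

    # merged with the processors that _get_logits_processor creates in step 8
    # output_ids = model.generate(**inputs, num_beams=4,
    #                             logits_processor=LogitsProcessorList([BanTokenLogitsProcessor(50256)]))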

Here is a brief introduction to the processors preset in transformers. Due to space limitations the source code is not shown; only their functions are summarized.

| processor | effect | reference link |
| --- | --- | --- |
| MinLengthLogitsProcessor | By forcing the probability of EOS to 0, ensures the generated result is at least a minimum length | / |
| MinNewTokensLengthLogitsProcessor | Similar to the previous one, but the prompt is not counted toward the generated length | / |
| RepetitionPenaltyLogitsProcessor | To prevent the "repeater" phenomenon, penalizes repeated tokens; proposed with the CTRL pretrained model | arxiv |
| EncoderRepetitionPenaltyLogitsProcessor | Unlike the previous one, the generated result must not repeat the encoder's input ids, rather than all currently given input ids | / |
| NoRepeatNGramLogitsProcessor | Prevents an n-gram (n consecutive words or characters) from being repeated in the generated text | github |
| EncoderNoRepeatNGramLogitsProcessor | n-grams may repeat within the encoder's input ids, but may not be repeated by the decoder | github |
| NoBadWordsLogitsProcessor | Ensures that certain words are never generated | / |
| PrefixConstrainedLogitsProcessor | Given a prefix_allowed_tokens_fn-style function, restricts which tokens may be generated | arxiv |
| HammingDiversityLogitsProcessor | Uses Hamming distance to ensure the generated beams differ enough from each other | arxiv |
| ForcedBOSTokenLogitsProcessor | Ensures the first generated token is a specific token | / |
| ForcedEOSTokenLogitsProcessor | When the maximum length is reached, forces the sequence to end with a specific token | / |
| InfNanRemoveLogitsProcessor | In the computed scores, replaces nan with 0 and inf with the largest representable value | / |
| SuppressTokensAtBeginLogitsProcessor | Once generation reaches a certain position, certain specific tokens are no longer generated | / |
| SuppressTokensLogitsProcessor | Sets the probability of certain tokens to -inf so they are never generated | / |
| ForceTokensLogitsProcessor | Builds a mapping table that forcibly maps a token to another token | / |
| WhisperTimeStampLogitsProcessor | Forces the model to generate timestamps (a timestamp is a special token; e.g. in a dialogue, query = "What day is it today?", answer = "Today is [timestamp]", where [timestamp] is later replaced by the actual time) | / |

4.9 Create the stopping criteria

stopping_criteria and logits_processor are the main means for users to intervene in the generation process. Compared with logits_processor, which forcibly alters the probability space, stopping_criteria directly defines the policy for terminating generation, which also makes it relatively easier to understand.

        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )

The preset criteria are summarized as follows:

| criteria | effect |
| --- | --- |
| MaxLengthCriteria | Stops generation when the generated sequence reaches the configured maximum length |
| MaxNewTokensCriteria | Stops generation when the generated sequence, excluding the prompt, reaches the configured maximum length |
| MaxTimeCriteria | Stops generation when generation exceeds a set time limit |

A custom criterion should inherit the abstract class transformers.generation.stopping_criteria.StoppingCriteria.
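A minimal sketch of such a custom criterion; the class name and the stopping condition are invented for illustration, and only the __call__ signature follows StoppingCriteria:

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnTokenCriteria(StoppingCriteria):
        """Illustrative example: stop once every sequence ends with a given token id."""
        def __init__(self, stop_token_id: int):
            self.stop_token_id = stop_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # input_ids: (batch_size * num_beams, current_length)
            return bool((input_ids[:, -1] == self.stop_token_id).all())

    # output_ids = model.generate(**inputs,
    #                             stopping_criteria=StoppingCriteriaList([StopOnTokenCriteria(13)]))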

4.10 Enter the corresponding branch

Here we go directly into the beam search branch. As mentioned above, to control hyperparameters such as temperature, you should enter the is_beam_sample_gen_mode branch instead.

4.11 Create the logits warpers

            # 11. prepare logits warper
            logits_warper = self._get_logits_warper(generation_config)

Logits warpers are used in the same way as logits processors: both modify the output probabilities. I have not found a good explanation of the exact difference between them; the warpers can be understood as controlling sampling-related parameters such as temperature and top_k, and they are applied after the logits processors.

Plain beam search does not involve this part; the logits warpers are only used when the sampling variant of beam search is selected.

Keep in mind that, like a processor, a warper takes the current sequence (token ids) and the previously computed scores as input, and returns the processed scores with shape (batch_size, config.vocab_size).
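A sketch of a call that goes down the beam sample branch and therefore triggers the warpers, reusing the model and inputs from the earlier sketch; the parameter values are arbitrary:

    # do_sample=True together with num_beams > 1 selects beam_sample, so
    # TemperatureLogitsWarper / TopKLogitsWarper / TopPLogitsWarper are created in step 11
    output_ids = model.generate(
        **inputs,
        num_beams=4,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        max_new_tokens=64,
    )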

The preset warpers include:

| warper | effect (for reference only) | reference link |
| --- | --- | --- |
| TemperatureLogitsWarper | Rescales the scores by dividing them all by temperature | / |
| TopPLogitsWarper | Keeps only the tokens within the top_p cumulative probability mass; the rest are filtered out | / |
| TopKLogitsWarper | Keeps only the top_k highest-scoring tokens; the probabilities of the rest are set to -inf | / |
| TypicalLogitsWarper | Typical decoding | arxiv |
| EpsilonLogitsWarper | Removes tokens whose probability is smaller than epsilon | arxiv |
| EtaLogitsWarper | Eta-sampling | arxiv |
| LogitNormalization | Normalizes the scores (log_softmax) during beam search | / |

4.12 beam search

This part is the core flow of beam search. Due to space limitations, the detailed generation process is covered in the second part of this article.

In this part, a BeamSearchScorer used for scoring is first created (its exact role is introduced in the second part), then input_ids is expanded according to num_beams, and finally the core beam_search method is executed (or the beam_sample method for beam sample mode).

            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

The second part of this article will introduce the basic principles of beam search, the main utility components involved in the transformers implementation of beam search, the generation and update process of beam search, and how beam sample builds on beam search. Interested readers can continue there.
