A detailed explanation of the generate method in transformers, using beam search as an example (Part 1)
- 1. The code location of generate
- 2. Overview of GenerationMixin
- 3. generate signature
- 4. generate process
- 4.1 Read and update generation config
- 4.2 Supplement the parameters that are not passed in
- 4.3 Defining model inputs
- 4.4 Defining other parameters of the model
- 4.5 Prepare input_ids for autoregressive models
- 4.6 Prepare the maximum length
- 4.7 Confirm Generation Mode
- 4.8 Create a logits processor
- 4.9 Create a stop rule
- 4.10 Enter the corresponding branch
- 4.11 Create logits warper
- 4.12 beam search
Compared with two years ago, natural language generation (NLG) has advanced considerably, and the transformers library is more widely used than ever. At inference time, the core call is model.generate(). This article explains in detail how the generate method works. generate supports many generation strategies; this article takes the most commonly used one, beam search, as its example and covers it in as much detail as I can.
Because of its length, the article is split into two parts. Part 1 covers the overall structure of the generate method, and Part 2 goes further into the code of the beam search part.
With the emergence of various LLMs, the generate-related code in transformers has changed. The main differences are:
- The source code location of generate has changed;
- The generate method now uses a generation_config parameter to manage all generation-related settings, and the logic has been reorganized to be clearer.
1. The code location of generate
In earlier versions of transformers (transformers~=4.9), the generate method is located in transformers/generation_utils.py, as a method of the GenerationMixin class.
In newer versions (transformers~=4.28), the generate method has moved to transformers/generation/utils.py, and it is still a method of the GenerationMixin class.
A pretrained model in Hugging Face (hf) format inherits from the PreTrainedModel class. Following PreTrainedModel up its inheritance chain, you can see that GenerationMixin is one of its base classes:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
This is why a model instantiated with AutoModel.from_pretrained() can call the generate method directly for inference.
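The mixin pattern described above can be illustrated with a minimal sketch. All class names below (ToyGenerationMixin, ToyModule, ToyPreTrainedModel) are hypothetical stand-ins invented for this example; they only mimic the inheritance structure and are not the real transformers classes.

```python
# Toy stand-ins: these classes only mimic the inheritance pattern,
# they are not the real transformers classes.
class ToyGenerationMixin:
    def generate(self, inputs):
        # The mixin only relies on the host class providing `forward`.
        return [self.forward(x) for x in inputs]


class ToyModule:
    """Plays the role of nn.Module in this sketch."""


class ToyPreTrainedModel(ToyModule, ToyGenerationMixin):
    def forward(self, x):
        return x + 1


model = ToyPreTrainedModel()
print(model.generate([1, 2, 3]))  # generate() is inherited from the mixin: [2, 3, 4]
```

The point of the design is that the generation logic is written once in the mixin, and every model class that inherits it gets generate for free as long as it implements its own forward.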
2. Overview of GenerationMixin
This section serves as a cheat sheet. Rather than reading it straight through, refer back to it while reading the code in Section 4.
The methods of the GenerationMixin class are summarized below:
Method name | Effect | Where it appears in this article |
---|---|---|
_validate_model_class | Check whether the model class supports generation, and raise the corresponding exception if it does not | 4.1 |
_validate_model_kwargs | Check that the leftover keyword arguments passed to generate are valid for the model | 4.1 |
_prepare_model_inputs | Prepare the inputs for the generation process | 4.3 |
_maybe_initialize_input_ids_for_generation | When the inputs of the generation process are empty, use the bos token for initialization | 4.3 |
_prepare_attention_mask_for_generation | Prepare attention_mask for the generation process | 4.4 |
_prepare_encoder_decoder_kwargs_for_generation | Prepare encoder-related parameters for the generation process | 4.4 |
_prepare_decoder_input_ids_for_generation | Additional handling of input_ids for autoregressive models | 4.5 |
_get_decoder_start_token_id | Get the token id of the starting position of the decoder, this id may be different from the bos | 4.5 |
_get_logits_processor | Create a logits processor | 4.8 |
_get_stopping_criteria | Create a stop rule | 4.9 |
_get_logits_warper | Create logits warper | 4.11 |
_expand_inputs_for_generation | Expand input_ids according to num_beams | 4.12 |
prepare_inputs_for_generation | Preprocess the input to the model | Part 3.1 |
adjust_logits_during_generation | Make adjustments to the calculated logits during generation | Part 3.1 |
_update_model_kwargs_for_generation | Update the generation parameters according to the generation result of a step | Part 5.6 |
_reorder_cache | According to the beam_idx updated by step, rearrange the cached past_k_v | Part 5.6 |
3. generate signature
Before introducing the process, let's take a look at the signature of the generate method. In version 4.28, the signature is simplified as follows:
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
Compared with the previous version's very long signature, the direct advantage is that far fewer parameters are passed in: parameters such as top_k, top_p, and num_beams are all folded into generation_config, which makes the function look simpler. These parameters can also be read directly from the generation_config.json file under the model path, which is convenient for users.
The corresponding disadvantage is that many parameters are no longer explicitly exposed, which is inconvenient when reading the docstring or customizing a configuration. The optional parameters have to be looked up in GenerationConfig:
from transformers.generation.configuration_utils import GenerationConfig
help(GenerationConfig)
(The parameters of GenerationConfig differ from one generation strategy to another, so they are not covered here. Part 2 of this article briefly introduces the parameters used by the beam search strategy.)
The meaning and function of the parameters of the generate method are introduced as follows:
Parameter name | Type | Meaning and function |
---|---|---|
inputs | torch.Tensor | The token id sequence produced by the tokenizer; the model computes logits from this sequence and generates. If empty, it defaults to the id of the bos token |
generation_config | GenerationConfig | The parameters for the various generation strategies. If empty, they are read from the generation_config.json file under the model path, or derived from the model config |
logits_processor | LogitsProcessorList | Further processes the logits computed by the model, for example penalizing the probability of repeated tokens to avoid the "repeater" phenomenon (the model repeating itself) |
stopping_criteria | StoppingCriteriaList | Rules for stopping the generation process, such as stopping once a certain length is reached or once a time limit is exceeded |
prefix_allowed_tokens_fn | Callable[[int, torch.Tensor], List[int]] | During beam search, restricts the token ids that may be generated at each step |
synced_gpus | bool | Used with DeepSpeed ZeRO |
streamer | BaseStreamer | Used for streaming generation (the output appears token by token) |
Of these parameters, logits_processor and stopping_criteria are the main means by which users can manually intervene in the generation process.
4. generate process
In version 4.28 of transformers, the generate code is clearly commented, so this article follows the step numbers used in those code comments.
4.1 Read and update generation config
The general logic of this part is to deal with the case where the generation config is None, and to check whether there are any wrong parameters inconsistent with the generation strategy.
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
    # legacy: users may modify the model configuration to control generation -- update the generation config
    # model attribute accordingly, if it was created from the model config
    if self.generation_config._from_model_config:
        new_generation_config = GenerationConfig.from_model_config(self.config)
        if new_generation_config != self.generation_config:
            warnings.warn(
                "You have modified the pretrained model configuration to control generation. This is a"
                " deprecated strategy to control generation and will be removed soon, in a future version."
                " Please use a generation configuration file (see"
                " https://huggingface.co/docs/transformers/main_classes/text_generation)"
            )
            self.generation_config = new_generation_config
    generation_config = self.generation_config

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
Neither _validate_model_class nor _validate_model_kwargs is central to the flow, so they are not covered here.
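The deepcopy-then-update pattern in step 1 can be sketched in isolation. ToyGenerationConfig below is a hypothetical three-field stand-in written for this example, not the real GenerationConfig; it only assumes the behavior stated in the code comment above, namely that update applies the kwargs it recognizes and returns the rest as model kwargs.

```python
import copy


class ToyGenerationConfig:
    """Hypothetical stand-in for GenerationConfig, with only three fields."""

    def __init__(self, num_beams=1, do_sample=False, max_length=20):
        self.num_beams = num_beams
        self.do_sample = do_sample
        self.max_length = max_length

    def update(self, **kwargs):
        # Known attributes are overwritten; everything else is handed back.
        unused = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                unused[key] = value
        return unused


default_config = ToyGenerationConfig()
# Work on a copy so the model-level default is never mutated by a single call.
call_config = copy.deepcopy(default_config)
model_kwargs = call_config.update(num_beams=4, attention_mask=[[1, 1, 1]])
```

After this call, call_config carries num_beams=4 for this generation only, the default config is untouched, and attention_mask is left over to be forwarded to the model.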
4.2 Supplement the parameters that are not passed in
The parameters filled in here are logits_processor, stopping_criteria, and the pad_token_id in generation_config. The first two default to empty lists. If pad_token_id is not given but eos_token_id is, the eos token is used for padding.
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
    if model_kwargs.get("attention_mask", None) is None:
        logger.warning(
            "The attention mask and the pad token id were not set. As a consequence, you may observe "
            "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
        )
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, list):
        eos_token_id = eos_token_id[0]
    logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
    generation_config.pad_token_id = eos_token_id
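The pad token fallback in step 2 boils down to a small rule, sketched here as a standalone function. The name resolve_pad_token_id is hypothetical and chosen for this example; the warnings from the original code are left out.

```python
def resolve_pad_token_id(pad_token_id, eos_token_id):
    """Return the pad id to use, falling back to the (first) eos id."""
    # An explicit pad id always wins; with no eos id there is nothing to fall back to.
    if pad_token_id is not None or eos_token_id is None:
        return pad_token_id
    # eos_token_id may be a list of ids; only the first one is used for padding.
    if isinstance(eos_token_id, list):
        eos_token_id = eos_token_id[0]
    return eos_token_id
```

For example, resolve_pad_token_id(None, [50256, 7]) picks 50256, while resolve_pad_token_id(0, 2) keeps the explicit pad id 0.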
4.3 Defining model inputs
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
    inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
The key method here is _prepare_model_inputs. In one sentence, its core job is to make sure that the model's input sequence input_ids is non-empty: if it is empty, it is initialized with bos_token. The rest of the method handles special cases for individual models:
def _prepare_model_inputs(
    self,
    inputs: Optional[torch.Tensor] = None,
    bos_token_id: Optional[int] = None,
    model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
    """
    This function extracts the model-specific `inputs` for generation.
    """
    # This step seems to act as a sanity check, guarding against encoder-decoder models
    # whose main model and encoder use different input names
    # 1. retrieve all kwargs that are non-None or non-model input related.
    # some encoder-decoder models have different names for model and encoder
    if (
        self.config.is_encoder_decoder
        and hasattr(self, "encoder")
        and self.encoder.main_input_name != self.main_input_name
    ):
        input_name = self.encoder.main_input_name
    else:
        input_name = self.main_input_name

    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

    # Make sure the inputs are not passed in twice
    # 2. check whether model_input_name is passed as kwarg
    # if yes and `inputs` is None use kwarg inputs
    inputs_kwarg = model_kwargs.pop(input_name, None)
    if inputs_kwarg is not None and inputs is not None:
        raise ValueError(
            f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
            f"Make sure to either pass {inputs} or {input_name}=..."
        )
    elif inputs_kwarg is not None:
        inputs = inputs_kwarg

    # For the `inputs_embeds` argument:
    # for a decoder-only model, `input_ids` is placed in model_kwargs;
    # for an encoder-decoder model, only one of input_ids and inputs_embeds may be passed
    # 3. In the presence of `inputs_embeds` for text models:
    # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
    # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
    # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
    # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
    # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
    if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
        if not self.config.is_encoder_decoder:
            has_inputs_embeds_forwarding = "inputs_embeds" in set(
                inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
            )
            if not has_inputs_embeds_forwarding:
                raise ValueError(
                    f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
                    "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                    "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                )
            # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
            # the attention mask) can rely on the actual model input.
            model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                inputs, bos_token_id, model_kwargs=model_kwargs
            )
        else:
            if inputs is not None:
                raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
        inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

    # 4. if `inputs` is still None, try to create `input_ids` from BOS token
    # If there are still no input_ids at this point, they are created from bos; roughly equivalent to:
    # torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
    inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
    return inputs, input_name, model_kwargs
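The BOS fallback in step 4 of _prepare_model_inputs can be sketched with plain lists instead of tensors. The function name maybe_initialize_input_ids and its batch_size argument are assumptions made for this example; the real method works on tensors and handles further cases that are omitted here.

```python
def maybe_initialize_input_ids(inputs, bos_token_id, batch_size=1):
    """Return `inputs` unchanged, or a (batch_size, 1) nested list filled with BOS."""
    if inputs is not None:
        return inputs
    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no inputs are provided.")
    # One row per sequence, each holding only the BOS token id.
    return [[bos_token_id] for _ in range(batch_size)]
```

Calling maybe_initialize_input_ids(None, bos_token_id=1, batch_size=2) gives [[1], [1]], which is the list equivalent of the torch.ones(...) * bos_token_id shorthand in the comment above.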
4.4 Defining other parameters of the model
This part needs no special attention. It mainly uses the config to fill in the remaining model parameters, such as creating the attention_mask and making sure that an encoder-decoder model returns a ModelOutput.
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
        inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
    )

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
    if (
        generation_config.pad_token_id is not None
        and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
    ):
        logger.warning(
            "A decoder-only architecture is being used, but right-padding was detected! For correct "
            "generation results, please set `padding_side='left'` when initializing the tokenizer."
        )

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
    # if model is encoder decoder encoder_outputs are created
    # and added to `model_kwargs`
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
        inputs_tensor, model_kwargs, model_input_name
    )
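To make the attention mask creation concrete, here is a list-based sketch of what a helper like _prepare_attention_mask_for_generation might do, assuming the rule "mask out pad positions, unless the pad id cannot be distinguished from the eos id". The function name prepare_attention_mask is hypothetical, and the real method operates on tensors.

```python
def prepare_attention_mask(input_ids, pad_token_id, eos_token_id):
    """Build a 0/1 mask with the same shape as `input_ids` (nested lists)."""
    has_pad = pad_token_id is not None and any(pad_token_id in row for row in input_ids)
    pad_differs_from_eos = eos_token_id is None or pad_token_id != eos_token_id
    if has_pad and pad_differs_from_eos:
        # Attend to every position that is not padding.
        return [[int(tok != pad_token_id) for tok in row] for row in input_ids]
    # Otherwise attend to everything.
    return [[1] * len(row) for row in input_ids]


# A left-padded batch (pad id 0), as recommended for decoder-only models.
mask = prepare_attention_mask([[0, 0, 5, 6], [7, 8, 9, 10]], pad_token_id=0, eos_token_id=2)
```

When pad and eos share the same id, a trailing eos cannot be told apart from padding, so this sketch falls back to an all-ones mask; this is one reason it is safer to pass an explicit attention_mask.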
4.5 Prepare input_ids for autoregressive models
The main difference between this step and 4.3 is the extra processing for autoregressive decoding. For an encoder-decoder model, the decoder's input_ids are created from decoder_start_token_id; otherwise, the inputs_tensor created in 4.3 is used directly as input_ids.
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
    # This mainly handles the case where the decoder's start token id differs from the bos id.
    # In that case, the decoder should start generating from the decoder start id specified in the config.
    # Roughly equivalent to: torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
    input_ids = self._prepare_decoder_input_ids_for_generation(
        batch_size,
        decoder_start_token_id=generation_config.decoder_start_token_id,
        bos_token_id=generation_config.bos_token_id,
        model_kwargs=model_kwargs,
        device=inputs_tensor.device,
    )

    # conditional generation for multi-modal models.
    if "input_ids" in model_kwargs and model_input_name == "pixel_values":
        input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
else:
    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
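The encoder-decoder branch can be reduced to the following list-based sketch. The function name decoder_start_ids is hypothetical, and the fallback from decoder_start_token_id to bos_token_id is an assumption based on both ids being passed to the helper above; the real method has more cases.

```python
def decoder_start_ids(batch_size, decoder_start_token_id=None, bos_token_id=None):
    """Return a (batch_size, 1) nested list holding the decoder start token."""
    # Assumption: prefer the dedicated decoder start id, fall back to bos.
    start_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
    if start_id is None:
        raise ValueError("Neither `decoder_start_token_id` nor `bos_token_id` is defined.")
    return [[start_id] for _ in range(batch_size)]
```

With decoder_start_token_id=0 and bos_token_id=1, every decoder sequence starts from [0], which shows how the start id can differ from the bos id.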
4.6 Prepare the maximum length
This part derives the effective maximum length from the relevant settings in the config and checks whether input_ids is already too long for it.
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
    warnings.warn(
        f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
        "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
        " recommend using `max_new_tokens` to control the maximum length of the generation.",
        UserWarning,
    )
elif generation_config.max_new_tokens is not None:
    generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
    if not has_default_max_length:
        logger.warn(
            f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
            f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
            "Please refer to the documentation for more information. "
            "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
            UserWarning,
        )

if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
    raise ValueError(
        f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
        f" the maximum length ({generation_config.max_length})"
    )
if input_ids_seq_length >= generation_config.max_length:
    input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
    logger.warning(
        f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
        f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
        " increasing `max_new_tokens`."
    )
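Stripped of its warnings, the length logic in step 6 is a short rule: max_new_tokens, when set, takes precedence and is counted on top of the prompt. The function resolve_max_length below is a hypothetical sketch of that rule, and the default of 20 for max_length is an assumption made for this example.

```python
def resolve_max_length(input_length, max_length=20, max_new_tokens=None):
    """Return the effective total length limit (prompt + generated tokens)."""
    if max_new_tokens is not None:
        # max_new_tokens only counts newly generated tokens, so add the prompt length.
        return input_length + max_new_tokens
    return max_length


def prompt_too_long(input_length, effective_max_length):
    """True when the prompt alone already reaches the limit (the warning case above)."""
    return input_length >= effective_max_length
```

For a 10-token prompt, resolve_max_length(10, max_new_tokens=50) gives 60, whereas relying on max_length=20 alone leaves room for only 10 new tokens.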
4.7 Confirm Generation Mode
Here we go straight to the beam search branches; the other modes are not covered, and the same applies below.
There are two kinds of beam search: the basic beam_gen_mode and the more advanced beam_sample_gen_mode. The former leads to the beam_search generation method, and the latter to the beam_sample generation method.
The main difference between the two is that beam_sample_gen_mode can use parameters such as temperature, top_k, and top_p to further control generation. How they are applied is covered in Section 4.11 (logits warper). The basic beam_gen_mode has no step that creates a logits warper.
# 7. determine generation mode
is_beam_sample_gen_mode = (
    (generation_config.num_beams > 1)
    and (generation_config.num_beam_groups == 1)
    and generation_config.do_sample is True
    and not is_constraint_gen_mode
    and not is_contrastive_search_gen_mode
)
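The choice between the two beam modes can be summarized with a simplified sketch. The function beam_mode is hypothetical; it deliberately ignores the constrained and contrastive search checks that appear in the condition above and lumps every non-beam case into "other".

```python
def beam_mode(num_beams, do_sample, num_beam_groups=1):
    """Simplified mode selection covering only the two beam search variants."""
    if num_beams > 1 and num_beam_groups == 1:
        # Sampling on top of beams selects beam_sample; otherwise plain beam_search.
        return "beam_sample" if do_sample else "beam_search"
    return "other"
```

So num_beams=4 with do_sample=False selects plain beam search, and switching do_sample to True with the same beams selects beam sample.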
4.8 Create a logits processor
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
    generation_config=generation_config,
    input_ids_seq_length=input_ids_seq_length,
    encoder_input_ids=inputs_tensor,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    logits_processor=logits_processor,
)
This step matters because it involves the logits processors, which adjust the computed scores at every step of the generation process. transformers ships several preset processors, and users can also define their own (by subclassing the abstract class transformers.generation.logits_process.LogitsProcessor) to implement custom logic that intervenes in generation.
In beam search, the logits processor is used as follows:
# used inside def beam_sample
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
Here, input_ids is the sequence of token ids fed to the model at the current step, a tensor of shape (batch_size, sequence_length), and next_token_scores is the score computed by the model (that is, the distribution over the vocabulary) after log_softmax.
Below is a brief overview of the processors preset in transformers. For reasons of space, the source code is not reproduced here; only their functions are summarized.
Processor | Effect | Reference |
---|---|---|
MinLengthLogitsProcessor | Ensures that the generated result reaches a minimum length by forcing the probability of EOS to 0 until then | / |
MinNewTokensLengthLogitsProcessor | Similar to the previous one, but the prompt does not count toward the generated length | / |
RepetitionPenaltyLogitsProcessor | Penalizes repeated tokens to prevent the "repeater" phenomenon; proposed with the pretrained model CTRL | arxiv |
EncoderRepetitionPenaltyLogitsProcessor | Like the previous one, but the penalty is computed against the encoder's input ids rather than all currently given input ids | / |
NoRepeatNGramLogitsProcessor | Prevents repeated n-grams (n consecutive words or characters) from appearing in the generated text; the restriction applies to runs of n consecutive tokens | github |
EncoderNoRepeatNGramLogitsProcessor | N-grams that occur in the encoder's input ids may not be repeated in the decoder output | github |
NoBadWordsLogitsProcessor | Ensures that certain words are never generated | / |
PrefixConstrainedLogitsProcessor | Given a prefix_allowed_tokens_fn, restricts which tokens may be generated | arxiv |
HammingDiversityLogitsProcessor | Uses Hamming distance to make sure the generated beams differ enough from one another | arxiv |
ForcedBOSTokenLogitsProcessor | Ensures that the first generated token is a specific token | / |
ForcedEOSTokenLogitsProcessor | Ensures that generation ends with a specific token when the maximum length is reached | / |
InfNanRemoveLogitsProcessor | In the computed scores, replaces nan with 0 and inf with the largest representable value | / |
SuppressTokensAtBeginLogitsProcessor | Suppresses certain tokens at the position where generation begins | / |
SuppressTokensLogitsProcessor | Sets the score of certain tokens to -inf so that they are never generated | / |
ForceTokensLogitsProcessor | Builds a map from generation position to token id and forces that token at that position | / |
WhisperTimeStampLogitsProcessor | Forces the model to generate timestamps (a timestamp is a special token; for example, in a dialogue with query="What day is it today?" and answer="Today is [timestamp]", the [timestamp] is later replaced with the actual time) | / |
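The calling convention of a logits processor, (input_ids, scores) in and adjusted scores out, can be sketched without any dependencies. ToyBanTokensProcessor and ToyProcessorList are hypothetical list-based stand-ins written for this example; a real custom processor would subclass LogitsProcessor and operate on tensors.

```python
import math


class ToyBanTokensProcessor:
    """Sets the score of banned token ids to -inf for every sequence."""

    def __init__(self, banned_token_ids):
        self.banned = set(banned_token_ids)

    def __call__(self, input_ids, scores):
        # `scores` has one row per sequence and one column per vocabulary id.
        return [
            [-math.inf if token_id in self.banned else score for token_id, score in enumerate(row)]
            for row in scores
        ]


class ToyProcessorList(list):
    """Applies each processor in order, feeding the output of one into the next."""

    def __call__(self, input_ids, scores):
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


processors = ToyProcessorList([ToyBanTokensProcessor([1])])
processed = processors([[0]], [[-0.5, -1.0, -2.0]])  # token id 1 can no longer be chosen
```

Because a score of -inf corresponds to probability 0, a banned token can never win a beam, which is the same mechanism the preset "suppress" and "bad words" processors rely on.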
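4.9 Create a stop rule
Together with logits_processor, stopping_criteria is one of the main ways users can intervene in the generation process. Whereas logits_processor forcibly reshapes the probability space, stopping_criteria directly sets the policy for ending generation, so it is comparatively easier to understand.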
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
    generation_config=generation_config, stopping_criteria=stopping_criteria
)
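The preset criteria are summarized below:
Criteria | Effect |
---|---|
MaxLengthCriteria | Stops generation when the generated sequence reaches the configured maximum length |
MaxNewTokensCriteria | Stops generation when the generated sequence, excluding the prompt, reaches the configured maximum length |
MaxTimeCriteria | Stops generation when it has taken longer than a given time limit |
A custom criterion should subclass the abstract class transformers.generation.stopping_criteria.StoppingCriteria.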
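A stopping criterion follows a similar convention: it receives the current input_ids and scores and returns whether to stop. The classes below are hypothetical list-based stand-ins for illustration; ToyStopOnToken in particular is an invented custom rule, and a real one would subclass StoppingCriteria and work on tensors.

```python
class ToyMaxLengthCriteria:
    """Stop once the sequences have reached `max_length` tokens."""

    def __init__(self, max_length):
        self.max_length = max_length

    def __call__(self, input_ids, scores):
        return len(input_ids[0]) >= self.max_length


class ToyStopOnToken:
    """Hypothetical custom rule: stop when every sequence ends with `token_id`."""

    def __init__(self, token_id):
        self.token_id = token_id

    def __call__(self, input_ids, scores):
        return all(row[-1] == self.token_id for row in input_ids)


class ToyCriteriaList(list):
    """Generation stops as soon as any single criterion says so."""

    def __call__(self, input_ids, scores):
        return any(criterion(input_ids, scores) for criterion in self)


criteria = ToyCriteriaList([ToyMaxLengthCriteria(5), ToyStopOnToken(2)])
```

Here criteria([[7, 8, 2]], None) is True because the custom rule fires, even though the length limit of 5 has not been reached yet.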
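4.10 Enter the corresponding branch
Here we go directly into the beam search branch. As noted earlier, to control hyperparameters such as temperature, you should enter the is_beam_sample_gen_mode branch instead.
4.11 Create logits warper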
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
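A logits warper is used in the same way as a logits processor: both modify the output probabilities. I have not yet found a good explanation of the difference between them; one way to think of it is that warpers control the parameters tied to the sampling strategy, such as temperature and top_k, and that they run after the logits processors.
Plain beam search does not involve this step; the logits warper is used only when beam search runs in sample mode.
The thing to remember is that its input is the same as a processor's, namely the current sequence (token ids) and the previously computed scores, and that it returns the processed scores with shape (batch_size, config.vocab_size).
The preset warpers include: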
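Warper | Effect (for reference only) | Reference |
---|---|---|
TemperatureLogitsWarper | Scales all scores by dividing them by temperature | / |
TopPLogitsWarper | Keeps the smallest set of most probable tokens whose cumulative probability reaches top_p and sets the scores of the rest to -inf | / |
TopKLogitsWarper | Keeps only the top_k most probable tokens and sets the scores of the rest to -inf | / |
TypicalLogitsWarper | typical decoding | arxiv |
EpsilonLogitsWarper | Removes tokens whose probability is below epsilon | arxiv |
EtaLogitsWarper | eta-sampling | arxiv |
LogitNormalization | Renormalizes the scores with log_softmax during beam search | / |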
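Two of the warpers are simple enough to sketch on a single row of scores. The functions temperature_warp and top_k_warp are hypothetical, list-based illustrations of the ideas of temperature scaling and top-k filtering; the real warpers are classes that operate on batched tensors.

```python
import math


def temperature_warp(scores, temperature):
    """Divide every score by the temperature (values below 1 sharpen the distribution)."""
    return [score / temperature for score in scores]


def top_k_warp(scores, top_k):
    """Keep the top_k highest scores and set all others to -inf."""
    top_k = min(top_k, len(scores))
    threshold = sorted(scores, reverse=True)[top_k - 1]
    return [score if score >= threshold else -math.inf for score in scores]


scores = [2.0, 1.0, 0.5, -1.0]
warped = top_k_warp(temperature_warp(scores, 0.5), top_k=2)
print(warped)  # [4.0, 2.0, -inf, -inf]
```

The order matters only for readability here: dividing by a positive temperature does not change the ranking, so the same two tokens survive the top-k filter either way.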
4.12 beam search
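This is the core flow of beam search. For reasons of space, the details of how generation is actually carried out are covered in Part 2 of this article.
This step first creates the BeamSearchScorer used for scoring (its role is explained in Part 2), then expands input_ids according to num_beams, and finally runs the core beam search method beam_search, or beam_sample in the beam sample case.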
beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    num_beams=generation_config.num_beams,
    device=inputs_tensor.device,
    length_penalty=generation_config.length_penalty,
    do_early_stopping=generation_config.early_stopping,
    num_beam_hyps_to_keep=generation_config.num_return_sequences,
    max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
    input_ids=input_ids,
    expand_size=generation_config.num_beams,
    is_encoder_decoder=self.config.is_encoder_decoder,
    **model_kwargs,
)
# 13. run beam search
return self.beam_search(
    input_ids,
    beam_scorer,
    logits_processor=logits_processor,
    stopping_criteria=stopping_criteria,
    pad_token_id=generation_config.pad_token_id,
    eos_token_id=generation_config.eos_token_id,
    output_scores=generation_config.output_scores,
    return_dict_in_generate=generation_config.return_dict_in_generate,
    synced_gpus=synced_gpus,
    **model_kwargs,
)
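The "interleave" in step 12 means that each input row is repeated num_beams times in place, so that the beams of one sample sit next to each other. The function expand_inputs is a hypothetical list-based sketch of that idea; the real method works on tensors and also expands the tensors held in model_kwargs.

```python
def expand_inputs(input_ids, expand_size):
    """Repeat every row `expand_size` times, keeping copies of a row adjacent."""
    return [list(row) for row in input_ids for _ in range(expand_size)]


expanded = expand_inputs([[1, 2], [3, 4]], expand_size=2)
print(expanded)  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```

With this layout, rows i * num_beams to (i + 1) * num_beams - 1 all belong to sample i, so a batch of size batch_size becomes batch_size * num_beams rows.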
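Part 2 of this article covers the basic principle of beam search, the main utility components involved in the transformers implementation of beam search, how beams are generated and updated, and how beam sample builds on beam search. Interested readers are welcome to read on.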