BART's training and inference follow different logic. How exactly do they differ?
During training
- Parallelism
Although each token must be predicted from the tokens before it, training can still be parallelized by using a mask. For a target sequence of length N, we build an N*N attention mask that is lower triangular: position i may only attend to positions 0..i. This solves the problem of the model seeing future words while generating the current one, so all N positions can be trained in a single forward pass.
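The lower-triangular idea can be sketched directly with `torch.tril` (a minimal illustration; the length `N` here is hypothetical):

```python
import torch

# Causal (lower-triangular) mask: position i may attend to positions 0..i,
# never to later positions.
N = 4  # hypothetical target length
causal = torch.tril(torch.ones(N, N, dtype=torch.bool))
print(causal)
# Row i has True only in columns 0..i, so during parallel training every
# position's loss is computed in one forward pass without "seeing" the future.
```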
The decoder_attention_mask is obtained through the decoder's _prepare_decoder_attention_mask method. Let's look at its parameters in detail:
- input_shape: the shape of the decoder input ids, [bsz, tgt_len]
- inputs_embeds: the embedded decoder inputs; its dtype and device are used when building the mask
- past_key_values_length: the number of key/value positions cached from previous decoding steps (0 during training)
When tgt_len is greater than 1, the program enters the first if branch and calls _make_causal_mask().
The full() function creates a matrix filled with a specified initialization value (here, the most negative value representable in the dtype).
The arange() function generates the sequence 0, 1, ..., mask.size(-1) - 1.
After a comparison between this sequence and the row indices, positions in the lower triangle are filled with 0, producing the causal mask matrix.
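Putting the three steps together, here is a hedged sketch in the style of the library's _make_causal_mask (the function name and wrapper here are illustrative, not the exact library code):

```python
import torch

def make_causal_mask(tgt_len, dtype=torch.float32):
    # full(): a (tgt_len, tgt_len) matrix filled with the dtype's most negative value
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    # arange(): column indices 0 .. tgt_len - 1
    mask_cond = torch.arange(mask.size(-1))
    # comparison: column j < row i + 1 selects the lower triangle (incl. diagonal),
    # and those positions are set to 0 so they survive the later softmax
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0.0)
    return mask.to(dtype)

m = make_causal_mask(3)
print(m)
# Lower triangle (including diagonal) is 0; the upper triangle holds a huge
# negative value, so after adding to attention scores and applying softmax,
# future positions get ~0 attention weight.
```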
After the mask matrix is obtained, it must be expanded along the batch dimension using expand(), yielding the final combined_attention_mask. Why does combined_attention_mask need past_key_values_length? Because during incremental decoding the keys also include positions cached from earlier steps, which must all remain visible, so the mask's key dimension is tgt_len + past_key_values_length.
After the attention_mask is obtained, it is passed into the first self-attention layer of the decoder, which computes the output vector for each position.
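Inside self-attention, the mask is simply added to the raw attention scores before softmax, so the large-negative positions end up with ~0 probability. A minimal sketch (dimensions are illustrative):

```python
import torch
import torch.nn.functional as F

bsz, heads, tgt_len, head_dim = 1, 1, 3, 4
q = torch.randn(bsz, heads, tgt_len, head_dim)
k = torch.randn(bsz, heads, tgt_len, head_dim)

# raw scaled dot-product scores: [bsz, heads, tgt_len, tgt_len]
scores = q @ k.transpose(-1, -2) / head_dim ** 0.5

# additive causal mask: -inf strictly above the diagonal
causal = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)

probs = F.softmax(scores + causal, dim=-1)
# each row sums to 1 and all "future" columns are exactly 0
```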
During inference
Because decoding is auto-regressive (each step's input is the previous step's output), the forward passes cannot be parallelized across positions.
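The sequential dependency can be sketched with a toy greedy-decoding loop. Everything here is hypothetical (the `model` callable, the token ids, and the toy logits), not BART's real generation API:

```python
def greedy_decode(model, encoder_out, bos_id, eos_id, max_len=20):
    """Each iteration needs the token produced by the previous one."""
    ids = [bos_id]
    for _ in range(max_len):
        logits = model(encoder_out, ids)           # decoder sees tokens so far
        last = logits[-1]                          # scores for the next token
        next_id = max(range(len(last)), key=last.__getitem__)  # greedy argmax
        ids.append(next_id)
        if next_id == eos_id:                      # stop once EOS is produced
            break
    return ids

# toy model: always predicts token 1 (which we treat as EOS here)
toy = lambda enc, ids: [[0.0, 1.0, 0.0]]
print(greedy_decode(toy, None, bos_id=0, eos_id=1))  # [0, 1]
```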
What is the function of the encoder_attention_mask passed into the decoder?
As shown in the code below, let's first see how this variable is used.
In the decoder code, the encoder_attention_mask is reshaped from [bsz, src_seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]. This 4D shape is what cross-attention expects.
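A hedged sketch of that reshape, in the style of the library's _expand_mask helper (the function name and wrapper here are illustrative): 1s in the padding mask become 0 (attend), and 0s become a large negative value (ignore padding).

```python
import torch

def expand_mask(mask, tgt_len, dtype=torch.float32):
    """[bsz, src_len] padding mask -> additive [bsz, 1, tgt_len, src_len] mask."""
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded                       # 1 -> 0, 0 -> 1
    # padding positions (now 1) are filled with a huge negative value
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

# one sequence of length 3 whose last token is padding
m4 = expand_mask(torch.tensor([[1.0, 1.0, 0.0]]), tgt_len=2)
print(m4.shape)  # torch.Size([1, 1, 2, 3])
```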
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
# the result computed by the decoder's self-attention
hidden_states=hidden_states,
# the encoder output; tensor shape is (bsz, seq_len, hidden_dim), used later to compute the keys and values
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
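To make the role of key_value_states concrete, here is a minimal cross-attention sketch: queries come from the decoder's hidden_states, while keys and values are projected from encoder_hidden_states. Dimensions and projection layers are illustrative, not BART's actual modules:

```python
import torch

bsz, src_len, tgt_len, dim = 2, 5, 3, 8
encoder_hidden_states = torch.randn(bsz, src_len, dim)   # encoder output
decoder_hidden_states = torch.randn(bsz, tgt_len, dim)   # self-attention output

q_proj = torch.nn.Linear(dim, dim)
k_proj = torch.nn.Linear(dim, dim)
v_proj = torch.nn.Linear(dim, dim)

q = q_proj(decoder_hidden_states)     # [bsz, tgt_len, dim] — from the decoder
k = k_proj(encoder_hidden_states)     # [bsz, src_len, dim] — from the encoder
v = v_proj(encoder_hidden_states)     # [bsz, src_len, dim] — from the encoder

attn = torch.softmax(q @ k.transpose(1, 2) / dim ** 0.5, dim=-1)
out = attn @ v                        # [bsz, tgt_len, dim]
```

This is why during incremental decoding the cross-attention key/value cache never changes: the encoder output is fixed once computed.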