[Procesamiento del lenguaje natural] [Modelo grande] Análisis del código fuente de la estructura del modelo BLOOM (versión independiente)

Análisis del código fuente de la estructura del modelo BLOOM (versión independiente)

Este artículo analiza el principio y la implementación de BLOOM basado en el código del modelo BLOOM en transformadores.

Blog relacionado
[Procesamiento del lenguaje natural] [Modelo grande] Análisis del código fuente de la estructura del modelo ChatGLM-6B (versión independiente)
[Procesamiento del lenguaje natural] [Modelo grande] Análisis del código fuente de la estructura del modelo BLOOM (versión independiente)
[Procesamiento del lenguaje natural] [Modelo grande] Ajuste de recursos extremadamente bajos de los métodos de modelos grandes Código de implementación LoRA y BLOOM-LORA
[Aprendizaje profundo] [Capacitación distribuida] Operaciones de comunicación colectiva y ejemplos de Pytorch
[Procesamiento del lenguaje natural] [Modelo grande] Chinchilla: modelo de lenguaje grande con capacitación y utilización informática óptimas
[Procesamiento del lenguaje natural] [Modelo grande] Modelo de lenguaje grande Prueba de herramienta de razonamiento BLOOM
[Procesamiento del lenguaje natural] [Modelo grande] GLM-130B: un modelo de lenguaje bilingüe previamente entrenado de código abierto
[Procesamiento del lenguaje natural] [ Modelo grande] 8 para una gran introducción a la multiplicación de matrices de bits de Transformer
[Procesamiento del lenguaje natural] [Modelo grande] BLOOM: Un modelo multilingüe con parámetros 176B y acceso abierto
[Procesamiento del lenguaje natural] [Serie ChatGPT] FLAN: El ajuste fino del modelo de lenguaje es un estudiante Zero-Shot
[Procesamiento del lenguaje natural] [Serie ChatGPT] ¿De dónde viene la inteligencia de ChatGPT?
[Procesamiento del lenguaje natural] [Serie ChatGPT] Aparición de modelos grandes

1. Máscara

1.1 Principio

BLOOM usa el Decodificador en el Transformador, que usa dos máscaras: (1) el relleno al construir el lote debe estar enmascarado; (2) en el Decodificador, el token actual solo puede ver el token a su izquierda, por lo que necesita atención estar enmascarado. La primera se llama Máscara de relleno y la segunda se llama Máscara causal.

​Máscara causal . Dada una longitud nnUna secuencia de n cuya matriz de puntuación de atención esA ∈ R n × n A\in\mathbb{R}^{n\times n}ARn × nA i , j A_{i,j}Ayo , jmostrar consulta qi \textbf{q}_iqyo和clave kj \textbf{k}_jkjpuntuación de atención. Sin embargo, la tarea de generación es de izquierda a derecha y no hay forma de ver los tokens a su derecha durante el proceso de generación. Para garantizar que " sólo el tobillo izquierdo sea visible " durante el entrenamiento , se puede lograr mediante la Máscara Causal. Específicamente, es para enmascarar la matriz de atención AA.El triángulo superior de A. La siguiente imagen esn = 5 n = 5norte=Máscara Causal para 5 casos.

inserte la descripción de la imagen aquí

​Máscara acolchada . Durante el entrenamiento del modelo, dado que las muestras de entrada tienen diferentes longitudes, se requiere relleno para igualar longitudes. Sin embargo, la parte de relleno debe ignorarse cuando el modelo se propaga hacia adelante y hacia atrás, por lo que se requiere la máscara de relleno. Padding Mask también es para la matriz de puntuación de atención AAA , por lo que su forma también debe ser la misma queAA.A es lo mismo. La siguiente figura es un ejemplo de una máscara de relleno con una longitud de 3 pero con un relleno de 5.

inserte la descripción de la imagen aquí

La máscara completa de la matriz de puntuación de atención El proceso se muestra en la siguiente figura.

inserte la descripción de la imagen aquí

1.2 Código

Máscara causal

def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    input_ids_shape:(batch_size, seq_length)
    """
    batch_size, target_length = input_ids_shape
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
    seq_ids = torch.arange(target_length, device=device)
    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False

    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
    return expanded_mask

Máscara acolchada

def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    mask: (batch_size, seq_length)
    """
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)

2. Función de activación

La función de activación de Bloom usa GELU \text{GELU}GELU ,GELU \text{GELU}GELU se puede aproximar como
GELU ( x ) ≈ 0.5 x ( 1 + tanh ⁡ ( 2 π ( x + 0.044715 x 3 ) ) ) \text{GELU}(x)\approx 0.5x(1+\tanh(\ sqrt{ \frac{2}{\pi}}(x+0.044715x^3)))GELÚ ( x )0,5 x ( 1 .)+a pescado (Pi2 ( x+0.044715 x3 )))

def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    GELLU前向
    """
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    GELU后向
    """
    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    """
    完整的GeLU激活函数
    """
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        input = ctx.saved_tensors
        tmp = bloom_gelu_back(grad_output, input)
        return tmp
    
class BloomGelu(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return GeLUFunction.apply(x)
        else:
            # 非训练时,只执行前向传播
            return bloom_gelu_forward(x)

3. capa MLP

MLP ( X , R ) = abandono ( GELU ( XW 1 ) W 2 ) + R ; X es la entrada, R es el residual\text{MLP}(X, R) = \text{dropout}(\text{GELU } (XW_1)W_2)+R; \quad X es la entrada, R es el residualMLP ( X ,r )=abandono ( GELU ( X W1) W2)+R ;X es la entrada, R es el residual

class BloomMLP(nn.Module):
    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size
        
        # 预训练时的张量并行度
        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        # h至4h的全链接层
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        self.gelu_impl = BloomGelu()
        # 4h到h的全链接层
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        # dorpout
        self.hidden_dropout = config.hidden_dropout

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """
        hidden_states: (batch_size, seq_length, hidden_size)
        residual与hidden_states形状相同
        """
        # 全链接层+GLUE
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
        
        # 将hidden_states从4h在映射会h
        # intermediate_output的形状同hidden_states
        if self.pretraining_tp > 1 and self.slow_but_exact:
            # 判断预训练时是否使用了张量并行,且要采用慢且精确的前向传播
            intermediate_output = torch.zeros_like(residual)
            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            for i in range(self.pretraining_tp):
                intermediate_output = intermediate_output + F.linear(
                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)
        # 对intermediate_output执行dropout后,加上残差residual
        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)

        return output

4. ALiBi: inyecta información de ubicación

1. Principio

BLOOM utiliza ALiBi para inyectar información de ubicación en el modelo. Dada una longitud LLLa secuencia de entrada de L , luego el iithde cada cabeza de atención.i个consultaqi ∈ R 1 × d ( 1 ≤ i ≤ L ) \textbf{q}_i\in\mathbb{R}^{1\times d}(1\leq i\leq L)qyoR1 × re (1iL ) para el primeroiii个 claveK ∈ R i × d \textbf{K}\in\mathbb{R}^{i\times d}kRLa puntuación de atención de i × d
es softmax ( qi K ⊤ ) \text{softmax}(\textbf{q}_i\textbf{K}^\top)softmax ( qyok )
Cuando se utiliza ALiBi, no es necesario agregar incrustaciones de posición a la red. Solo es necesario agregar un sesgo estático al producto escalar de la clave de consulta.
softmax ( qi K ⊤ + m ⋅ [ − ( i − 1 ) , … , − 2 , − 1 , 0 ] ) \text{softmax}(\textbf{q}_i\textbf{K}^\top+m\ cdot[-(i-1),\puntos,-2,-1,0])softmax ( qyok+metro[ - ( yo1 ) ,,2 ,1 ,0 ])
dondemmm es la pendiente (pendiente) relacionada con la cabeza de atención, que es el hiperparámetro;[ − ( i − 1 ) , … , − 2 , − 1 , 0 ] [-(i-1),\dots,-2, -1,0][ - ( yo1 ) ,,2 ,1 ,0 ] es en realidadqi \textbf{q}_iqyoLa distancia relativa desde cada tecla.

​ Para 8 cabezales de atención, mmm es una secuencia de razones:1 2 1 , 1 2 2 , … , 1 2 8 \frac{1}{2^1},\frac{1}{2^2},\dots,\frac{1}{ 2^8}211,221,,281. Para un modelo con 16 cabezales de atención, mmm es una secuencia proporcional:1 2 0.5 , 1 2 1 , 1 2 1.5 , … , 1 8 \frac{1}{2^{0.5}},\frac{1}{2^1},\frac{ 1 }{2^{1.5}},\puntos,\frac{1}{8}20,51,211,21.51,,81

2. Date cuenta

def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    # closet_power_of_2是与num_head接近的2的次方
    # 例如:num_heads为5/6/7时,closest_power_of_2为4
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    # 计算斜率
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    
    # 注意力头数量不是2的次方
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
        
    # 相对距离
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    # alibi会与query和key的乘积相加
    # alibi的形状为[batch_size, num_heads, query_length, key_length]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

​ Durante la implementación, para evitar la operación de división en el cálculo de la pendiente, la pendiente se calcula de la siguiente manera:
base = 2 − ( 2 − ( log ⁡ 2 n − 3 ) ) = 1 2 8 / n = 1 2 8 n potencia = [ 1 , … , n ] \begin{align} &\text{base} = 2^{-(2^{-(\log_2 n-3)})}=\frac{1}{2^{8 /n }}=\frac{1}{\sqrt[n]{2^8}} \\ &\text{potencia} = [1,\dots, n] \\ \end{align}base=2( 2( l o g2norte - 3 ) )=28/ norte1=norte28 1fuerza=[ 1 ,,norte ]

El valor de retorno de la función es m ⋅ [ − ( i − 1 ) , … , − 2 , − 1 , 0 ] m\cdot[-(i-1),\dots,-2,-1,0]metro[ - ( yo1 ) ,,2 ,1 ,0 ]

5. Capa de atención de múltiples cabezas

1. Principio

La atención multicabezal BLOOM consiste en agregar ALiBi a la atención multicabezal estándar.

Cambiar :
Q = W q XK = W k XV = W v X Atención ( Q , K , V , A ) = softmax ( QKT dk + A ) V \begin {align} Q &= W_q X \\ K &= W_k X \\ V &= W_v X \\ \text{Atención}(Q,K,V,A) &= \text{softmax}(\frac{QK^T}{\sqrt{d_k}} +A)V \end{alinear}qkVAtención ( Q ,k ,V ,Un )=W.qX=W.kX=W.vX=softmax (dk q kt+A ) V
Entre ellos, XXX se importa,W q , W k , W v W_q,W_k,W_vW.q,W.k,W.vSon la matriz de proyección de consulta, clave y valor respectivamente, AAA es la matriz de sesgo de ALiBi.

Atención múltiple :

La atención de múltiples cabezas es la unión de los resultados de la atención de múltiples cabezas.
cabeza i = Atención ( Q i , K i , V i , A i ) MultiHead ( Q , K , V , A ) = Concat ( cabeza 1 , … , cabeza h ) W o \begin{align} \text{head} _i&=\text{Atención}(Q_i,K_i,V_i,A_i) \\ \text{MultiHead}(Q,K,V,A)&=\text{Concat}(\text{head}_1,\dots, \text{head}_h)W_o \end{align}cabezayoCabezal múltiple ( Q ,k ,V ,Un )=Atención ( Qyo,kyo,Vyo,Ayo)=Concat ( cabeza1,,cabezah) Wo

2. Date cuenta

class BloomAttention(nn.Module):
    def __init__(self, config: BloomConfig):
        super().__init__()
        # 预训练时,张量并行相关的参数(这里不需要关注)
        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        
        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {
      
      self.hidden_size} and `num_heads`:"
                f" {
      
      self.num_heads})."
            )

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0
        
        # query、key、value的投影层
        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        # 输出投影层
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        
	def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        fused_qkv: [batch_size, seq_length, 3*hidden_size]
        """
        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
        # 1. 将Q、K、V拆分出来;2. 拆分出多个头
        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
    
    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        # 目标:batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads
        # 将batch_size拆分出来:batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)
        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        # [batch_size, seq_length, 3 x hidden_size]
        # 一次性得到投影的Q、K、V,减少执行矩阵乘法的次数
        fused_qkv = self.query_key_value(hidden_states)
        
        # 多头拆分
        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        batch_size, q_length, _, _ = query_layer.shape
        
        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        
        # 处理传入的key和value(忽略)
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key, key_layer), dim=2)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, _, kv_length = key_layer.shape
        
        # 忽略
        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None
            
        # [batch_size * num_heads, q_length, kv_length]
        # inv_norm_factor*(query_layer*key_layer) + beta*alibi
        matmul_result = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )
        
        # [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
        
        # 若输入类型是float16,则将注意力分数转换为float32
        # 注意力分数的精度会显著影响模型的效果
        input_dtype = attention_scores.dtype
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        
        # mask
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        # softmax
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # [batch_size, num_heads, q_length, kv_length]
        # dropout
        attention_probs = self.attention_dropout(attention_probs)
        
        # 若传入注意力头的mask
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # attention_probs_reshaped:[batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

        # context_layer: [batch_size * num_heads, q_length, head_dim]
        # 乘以value
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)
        
        # context_layer: batch_size, seq_length, num_heads * head_dim
        # 合并多头
        context_layer = self._merge_heads(context_layer)

        # 输出投影
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)
            
        # dropout+残差
        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs

6. Bloque de floración

inserte la descripción de la imagen aquí

class BloomBlock(nn.Module):
    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(config)
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]
        # 先对hidden_states进行Layer Norm
        layernorm_output = self.input_layernorm(hidden_states)

        # 残差
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output, residual)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

7. Modelo Bloom

inserte la descripción de la imagen aquí

class BloomModel(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head
        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # BloomBlocks
        self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
        # 最终Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        self.post_init()

    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
        """
        封装build_alibi_tensor函数
        """
        return build_alibi_tensor(attention_mask, num_heads, dtype)

    def get_input_embeddings(self):
        return self.word_embeddings

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # 构建注意力分数的mask句子,见文章第一节的掩码(Mask)部分
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            # 构建causal mask
            combined_attention_mask = _make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        # 构建padding mask
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        # 两种mask合并
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )
        return combined_attention_mask

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        ### (开始)一些输入输出和参数设置,可以忽略
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {
      
      deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))
        ### (结束)一些输入输出和参数设置,可以忽略

        # 准备head mask,1.0表示保留注意力头
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
            
        # 在embedding后添加了layernorm
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        
        ### (开始) gradient checkpointing和past_key_values处理,忽略
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        ### (结束) gradient checkpointing和past_key_values处理,忽略
        
        # 构建注意力分数掩码
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )
        
        # BloomBlock前向传播
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    layer_past,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

Supongo que te gusta

Origin blog.csdn.net/bqw18744018044/article/details/130694055
Recomendado
Clasificación