Análisis del código fuente de Baichuan2: Baichuan2-13B-Chat/modelling_baichuan.py

# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.

from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers

import os
from contextlib import contextmanager
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


def _gen_alibi_mask(tensor, n_head, max_pos):
    slopes = torch.Tensor(_get_interleave(n_head))
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        self.o_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
        if xops is not None and self.training:
            attn_weights = None
            # query_states = query_states.transpose(1, 2)
            # key_states = key_states.transpose(1, 2)
            # value_states = value_states.transpose(1, 2)
            # attn_output = xops.memory_efficient_attention(
            #     query_states, key_states, value_states, attn_bias=attention_mask
            # )
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
            attn_output = attn_output.transpose(1, 2)
        else:
            attn_weights = torch.matmul(
                query_states, key_states.transpose(2, 3)
            ) / math.sqrt(self.head_dim)

            if attention_mask is not None:
                if q_len == 1:  # inference with cache
                    if len(attention_mask.size()) == 4:
                        attention_mask = attention_mask[:, :, -1:, :]
                    else:
                        attention_mask = attention_mask[:, -1:, :]
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
                )

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
  1. from .configuration_baichuan import BaichuanConfig

    Importar clases configuration_baichuanen módulos bajo el paquete actual BaichuanConfig.

  2. from .generation_utils import build_chat_input, TextIterStreamer

    Importar y generation_utilsen módulos bajo el paquete actual .build_chat_inputTextIterStreamer

  3. import math

    Importe la biblioteca de funciones matemáticas incorporada de Python.

  4. from threading import Thread

    Importe la clase en la biblioteca multiprocesoThread de Python .

  5. from typing import List, Optional, Tuple, Union

    Importe la biblioteca de anotaciones de tipos de Python. Aquí se importan List, Optional, Tupley Union.

  6. import torch

    Importe el marco PyTorch.

  7. from torch import nn

    Importe la biblioteca de redes neuronales desde PyTorch.

  8. from torch.nn import CrossEntropyLoss

    Importe la función de pérdida de entropía cruzada de la biblioteca de redes neuronales de PyTorch.

  9. from torch.nn import functional as F

    Importe el módulo de funciones de la biblioteca de redes neuronales de PyTorch y asígnele un alias F.

  10. from transformers import PreTrainedModel, PretrainedConfig

Importar y desde transformersla biblioteca .PreTrainedModelPretrainedConfig

  1. from transformers.activations import ACT2FN

    transformersImporte la tabla de mapeo de funciones de activación de la biblioteca ACT2FN.

  2. from transformers.generation.utils import GenerationConfig

    Importar desde módulo transformersen biblioteca .generation.utilsGenerationConfig

  3. from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

    Importar y desde módulostransformers en la biblioteca .modeling_outputsBaseModelOutputWithPastCausalLMOutputWithPast

  4. from transformers.utils import logging, ContextManagers

    Importar y desde módulos transformersen la biblioteca .utilsloggingContextManagers

  5. import os

    Importe el módulo de sistema operativo integrado de Python para manejar tareas relacionadas con el sistema operativo.

  6. from contextlib import contextmanager

    contextlibImportado del módulo de Python contextmanager, se utiliza para crear administradores de contexto.

  7. from accelerate import init_empty_weights

    Importar funciones desde acceleratebibliotecas init_empty_weights.

  8. logger = logging.get_logger(__name__)

    transformersCrea un objeto registrador utilizando la funcionalidad de registro proporcionada, __name__que es el nombre del módulo actual.

19-26 try.... except ImportErrorbloque de código:

 
 

irCopiar código

试图从 `xformers` 库中导入 `ops` 模块,并为其取别名`xops`。如果导入失败(即没有正确安装`xformers`库),则将`xops`设为`None`,并发出一个警告消息。

28-39. _get_interleave(n):

 
 

irCopiar código

定义了一个辅助函数`_get_interleave`。这个函数有一个内嵌函数`_get_interleave_power_of_2`,用于计算并返回一个列表,该列表的长度与给定的数字`n`相同,元素为从开始值开始的等比数列。主函数`_get_interleave`根据`n`是否是2的整数次幂来调用内部函数,并返回一个列表。

41-43. _fill_with_neg_inf(t):

 
 

irCopiar código

定义了一个辅助函数`_fill_with_neg_inf`,它接受一个张量`t`,将其填充为负无穷大,并返回该张量。

45-50. _buffered_future_mask(tensor, maxpos, alibi, attn_heads):

 
 

irCopiar código

定义了一个辅助函数`_buffered_future_mask`,该函数用于生成一个未来掩码(常用于Transformer的自注意力机制)。

52-64. _gen_alibi_mask(tensor, n_head, max_pos):

 
 

irCopiar código

定义了一个辅助函数`_gen_alibi_mask`,用于生成一个"alibi"掩码。

66-77.Clase RMSNorm:

 
 

Copiar código

定义了一个层归一化的变体:RMSNorm。这是一个神经网络模块,其核心功能是通过平方的均值进行归一化。

79-93.Clase MLP:

 
 

irCopiar código

定义了一个多层感知机(MLP)类,这是一个神经网络模块,包含三个线性层和一个激活函数。在其前向传播中,输入`x`首先经过`gate_proj`层和激活函数,然后与`up_proj`的输出相乘,最后经过`down_proj`层。

Este código contiene principalmente algunas funciones auxiliares y dos módulos de red neuronal: RMSNormyMLP . Estas funciones se pueden utilizar en modelos Transformer más grandes u otros modelos de redes neuronales.

Baichuan_atención:
 

Esta es una BaichuanAttentionclase denominada que define un módulo de mecanismo de autoatención, que es similar al mecanismo de atención en BERT, Transformer y otros modelos. Aquí hay una explicación línea por línea del código:

  1. class BaichuanAttention(torch.nn.Module):Defina una BaichuanAttentionclase llamada , de la que hereda torch.nn.Module, lo que significa que este es un módulo de red neuronal PyTorch.

  2. def __init__(self, config: BaichuanConfig):Defina un constructor que acepte un BaichuanConfigparámetro de tipo.

  3. super().__init__()Llamar al constructor de la clase principal es una operación normal al definir su propia capa de red en PyTorch.

  4. self.config = configGuarde la configuración entrante como una propiedad de la clase.

  5. self.hidden_size = config.hidden_sizeObtenga la propiedad de la configuración hidden_sizey guárdela como propiedad de la clase.

  6. self.num_heads = config.num_attention_headsObtenga el número de cabezas de atención de la configuración y guárdelo como una propiedad de la clase.

  7. self.head_dim = self.hidden_size // self.num_headsLas dimensiones de cada cabeza de atención se calculan y guardan como atributos de la clase.

  8. self.max_position_embeddings = config.model_max_lengthObtenga la longitud máxima del modelo de la configuración y guárdela como propiedad de la clase.

9 a 12. if (self.head_dim * self.num_heads) != self.hidden_size:Verifique si el tamaño de la capa oculta es divisible por el número de cabezas de atención.

13 a 15. self.W_pack = torch.nn.Linear(...)Defina una capa lineal que transforme linealmente el estado oculto de la entrada para obtener la consulta, la clave y el valor.

16 a 18. self.o_proj = torch.nn.Linear(...)Defina una capa lineal para la salida después del mecanismo de atención.

19 a 22. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):Defina una función auxiliar que remodele el tensor dado para adaptarlo al mecanismo de atención.

23 a 32. def forward(...):Defina la función de propagación hacia adelante del módulo.

  1. bsz, q_len, _ = hidden_states.size()Obtenga el tamaño del lote, la longitud de la secuencia y el tamaño de la capa oculta del tensor de entrada.

34 a 38. Esta parte del código transforma linealmente el estado oculto de la entrada para obtener la consulta, la clave y el valor.

39 a 48. Esta parte del código da forma a la consulta, las claves y los valores para que se ajusten al cálculo de la atención.

49 a 54. Si se proporciona past_key_value, se concatena con la clave y el valor actuales.

55-56. Si use_cachees así True, guarde la clave y el valor.

57-68 Determinar si está instalado xformersy utilizar diferentes métodos de cálculo de atención según las condiciones.

69 a 81. Si no se utiliza xformers, se utiliza el método de cálculo de autoatención del producto escalar convencional.

82 y 83. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)Ajustar la forma de la atención.

  1. attn_output = self.o_proj(attn_output)Pase la salida de atención a través de una capa lineal.

85 a 87. Dependiendo del output_attentionsvalor de , el peso de atención puede devolverse o establecerse en None.

  1. return attn_output, attn_weights, past_key_valueDevuelve el resultado de la atención, el peso de la atención y los pares clave-valor anteriores.

Este módulo implementa la autoatención del producto escalable, que es un componente clave en la arquitectura Transformer .

Supongo que te gusta

Origin blog.csdn.net/sinat_37574187/article/details/133090157
Recomendado
Clasificación