# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer
import math
from threading import Thread
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers
import os
from contextlib import contextmanager
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
try:
from xformers import ops as xops
except ImportError:
xops = None
logger.warning(
"Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
)
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
new_future_mask = _future_mask.to(tensor)
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def _gen_alibi_mask(tensor, n_head, max_pos):
slopes = torch.Tensor(_get_interleave(n_head))
position_point = torch.arange(max_pos) - max_pos + 1
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, epsilon=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.epsilon = epsilon
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
# convert into half-precision
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class MLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class BaichuanAttention(torch.nn.Module):
def __init__(self, config: BaichuanConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.model_max_length
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
)
self.W_pack = torch.nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=False
)
self.o_proj = torch.nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = (
proj.unflatten(-1, (3, self.hidden_size))
.unsqueeze(0)
.transpose(0, -2)
.squeeze(-2)
)
query_states = (
proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
key_states = (
proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
value_states = (
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if xops is not None and self.training:
attn_weights = None
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)
# attn_output = xops.memory_efficient_attention(
# query_states, key_states, value_states, attn_bias=attention_mask
# )
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
attn_output = attn_output.transpose(1, 2)
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None:
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
-
from .configuration_baichuan import BaichuanConfig
Importar clases
configuration_baichuan
en módulos bajo el paquete actualBaichuanConfig
. -
from .generation_utils import build_chat_input, TextIterStreamer
Importar y
generation_utils
en módulos bajo el paquete actual .build_chat_input
TextIterStreamer
-
import math
Importe la biblioteca de funciones matemáticas incorporada de Python.
-
from threading import Thread
Importe la clase en la biblioteca multiproceso
Thread
de Python . -
from typing import List, Optional, Tuple, Union
Importe la biblioteca de anotaciones de tipos de Python. Aquí se importan
List
,Optional
,Tuple
yUnion
. -
import torch
Importe el marco PyTorch.
-
from torch import nn
Importe la biblioteca de redes neuronales desde PyTorch.
-
from torch.nn import CrossEntropyLoss
Importe la función de pérdida de entropía cruzada de la biblioteca de redes neuronales de PyTorch.
-
from torch.nn import functional as F
Importe el módulo de funciones de la biblioteca de redes neuronales de PyTorch y asígnele un alias
F
. -
from transformers import PreTrainedModel, PretrainedConfig
Importar y desde transformers
la biblioteca .PreTrainedModel
PretrainedConfig
-
from transformers.activations import ACT2FN
transformers
Importe la tabla de mapeo de funciones de activación de la bibliotecaACT2FN
. -
from transformers.generation.utils import GenerationConfig
Importar desde módulo
transformers
en biblioteca .generation.utils
GenerationConfig
-
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
Importar y desde módulos
transformers
en la biblioteca .modeling_outputs
BaseModelOutputWithPast
CausalLMOutputWithPast
-
from transformers.utils import logging, ContextManagers
Importar y desde módulos
transformers
en la biblioteca .utils
logging
ContextManagers
-
import os
Importe el módulo de sistema operativo integrado de Python para manejar tareas relacionadas con el sistema operativo.
-
from contextlib import contextmanager
contextlib
Importado del módulo de Pythoncontextmanager
, se utiliza para crear administradores de contexto. -
from accelerate import init_empty_weights
Importar funciones desde
accelerate
bibliotecasinit_empty_weights
. -
logger = logging.get_logger(__name__)
transformers
Crea un objeto registrador utilizando la funcionalidad de registro proporcionada,__name__
que es el nombre del módulo actual.
19-26 try
.... except ImportError
bloque de código:
irCopiar código
试图从 `xformers` 库中导入 `ops` 模块,并为其取别名`xops`。如果导入失败(即没有正确安装`xformers`库),则将`xops`设为`None`,并发出一个警告消息。
28-39. _get_interleave(n)
:
irCopiar código
定义了一个辅助函数`_get_interleave`。这个函数有一个内嵌函数`_get_interleave_power_of_2`,用于计算并返回一个列表,该列表的长度与给定的数字`n`相同,元素为从开始值开始的等比数列。主函数`_get_interleave`根据`n`是否是2的整数次幂来调用内部函数,并返回一个列表。
41-43. _fill_with_neg_inf(t)
:
irCopiar código
定义了一个辅助函数`_fill_with_neg_inf`,它接受一个张量`t`,将其填充为负无穷大,并返回该张量。
45-50. _buffered_future_mask(tensor, maxpos, alibi, attn_heads)
:
irCopiar código
定义了一个辅助函数`_buffered_future_mask`,该函数用于生成一个未来掩码(常用于Transformer的自注意力机制)。
52-64. _gen_alibi_mask(tensor, n_head, max_pos)
:
irCopiar código
定义了一个辅助函数`_gen_alibi_mask`,用于生成一个"alibi"掩码。
66-77.Clase RMSNorm
:
Copiar código
定义了一个层归一化的变体:RMSNorm。这是一个神经网络模块,其核心功能是通过平方的均值进行归一化。
79-93.Clase MLP
:
irCopiar código
定义了一个多层感知机(MLP)类,这是一个神经网络模块,包含三个线性层和一个激活函数。在其前向传播中,输入`x`首先经过`gate_proj`层和激活函数,然后与`up_proj`的输出相乘,最后经过`down_proj`层。
Este código contiene principalmente algunas funciones auxiliares y dos módulos de red neuronal: RMSNorm
yMLP
. Estas funciones se pueden utilizar en modelos Transformer más grandes u otros modelos de redes neuronales.
Baichuan_atención:
Esta es una BaichuanAttention
clase denominada que define un módulo de mecanismo de autoatención, que es similar al mecanismo de atención en BERT, Transformer y otros modelos. Aquí hay una explicación línea por línea del código:
-
class BaichuanAttention(torch.nn.Module):
Defina unaBaichuanAttention
clase llamada , de la que heredatorch.nn.Module
, lo que significa que este es un módulo de red neuronal PyTorch. -
def __init__(self, config: BaichuanConfig):
Defina un constructor que acepte unBaichuanConfig
parámetro de tipo. -
super().__init__()
Llamar al constructor de la clase principal es una operación normal al definir su propia capa de red en PyTorch. -
self.config = config
Guarde la configuración entrante como una propiedad de la clase. -
self.hidden_size = config.hidden_size
Obtenga la propiedad de la configuraciónhidden_size
y guárdela como propiedad de la clase. -
self.num_heads = config.num_attention_heads
Obtenga el número de cabezas de atención de la configuración y guárdelo como una propiedad de la clase. -
self.head_dim = self.hidden_size // self.num_heads
Las dimensiones de cada cabeza de atención se calculan y guardan como atributos de la clase. -
self.max_position_embeddings = config.model_max_length
Obtenga la longitud máxima del modelo de la configuración y guárdela como propiedad de la clase.
9 a 12. if (self.head_dim * self.num_heads) != self.hidden_size:
Verifique si el tamaño de la capa oculta es divisible por el número de cabezas de atención.
13 a 15. self.W_pack = torch.nn.Linear(...)
Defina una capa lineal que transforme linealmente el estado oculto de la entrada para obtener la consulta, la clave y el valor.
16 a 18. self.o_proj = torch.nn.Linear(...)
Defina una capa lineal para la salida después del mecanismo de atención.
19 a 22. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
Defina una función auxiliar que remodele el tensor dado para adaptarlo al mecanismo de atención.
23 a 32. def forward(...):
Defina la función de propagación hacia adelante del módulo.
bsz, q_len, _ = hidden_states.size()
Obtenga el tamaño del lote, la longitud de la secuencia y el tamaño de la capa oculta del tensor de entrada.
34 a 38. Esta parte del código transforma linealmente el estado oculto de la entrada para obtener la consulta, la clave y el valor.
39 a 48. Esta parte del código da forma a la consulta, las claves y los valores para que se ajusten al cálculo de la atención.
49 a 54. Si se proporciona past_key_value
, se concatena con la clave y el valor actuales.
55-56. Si use_cache
es así True
, guarde la clave y el valor.
57-68 Determinar si está instalado xformers
y utilizar diferentes métodos de cálculo de atención según las condiciones.
69 a 81. Si no se utiliza xformers
, se utiliza el método de cálculo de autoatención del producto escalar convencional.
82 y 83. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Ajustar la forma de la atención.
attn_output = self.o_proj(attn_output)
Pase la salida de atención a través de una capa lineal.
85 a 87. Dependiendo del output_attentions
valor de , el peso de atención puede devolverse o establecerse en None
.
return attn_output, attn_weights, past_key_value
Devuelve el resultado de la atención, el peso de la atención y los pares clave-valor anteriores.
Este módulo implementa la autoatención del producto escalable, que es un componente clave en la arquitectura Transformer .