[PNL] Interpretación de LLaMA y LLamMA2

Resumen

Meta propuso recientemente los parámetros del modelo LLaMA (modelo de lenguaje básico abierto y eficiente) que incluyen múltiples versiones de 7B a 65B. En particular, LLaMA-13B supera a GPT-3 y es más de 10 veces más pequeño, y LLaMA-65B es competitivo con Chinchilla-70B y PaLM-540B.

I. Introducción

En general, cuanto más grande es el modelo, mejores son los resultados. Sin embargo, se ha señalado en [1] que cuando se da un presupuesto para el cálculo, el mejor rendimiento no es el modelo más grande, sino un modelo pequeño con más datos para el entrenamiento. Para un presupuesto de cálculo dado, las leyes de escala pueden calcular cómo elegir el tamaño del volumen de datos y el tamaño del modelo. Sin embargo, esto ignora el presupuesto de inferencia, que es fundamental para la inferencia del modelo. Cuando se da un objetivo de rendimiento del modelo, el mejor modelo no es el modelo más rápido para el entrenamiento, sino el modelo más rápido para el razonamiento. Aunque en este caso sería más económico entrenar un modelo más grande.

En [2] se recomienda que entrenar un modelo 10B requiera tokens de 200B, pero el experimento en este documento encontró que después de entrenar un modelo 7B con tokens de 1T, el rendimiento sigue aumentando. El objetivo de este artículo es brindar una serie de LLM con el mejor rendimiento posible mediante la capacitación en datos a gran escala.

2. Datos previos al entrenamiento

2.1 Conjunto de datos

El corpus de capacitación es un corpus mixto de código abierto, la proporción de chino es muy baja y casi no se admite chino. Las proporciones detalladas son: CommonCrawl 67 %, C4 15 %, GitHub 4,5 %, Wikipedia 4,5 %, Books 4,5 %, ArXiv 2,5 %, Stack Exchange 2 %.

Hay tokens de 1.4T en total, y la mayoría de los datos de entrenamiento solo se usan una vez, excepto Wikipedia y Books, que usan alrededor de 2 épocas.

2.2 Tokenizador

Utilizando el algoritmo de codificación de pares de bytes (BPE), se utiliza la implementación de Sentence-Piece. Todos los números se dividen en dígitos individuales y todos los caracteres UTF-8 desconocidos se reducen a bytes para su descomposición. Por lo tanto, LLaMA puede construir muchos caracteres que no están en el vocabulario por medio de bytes, por lo que también tiene una mejor capacidad multilingüe.

3. Mejora de la estructura de la red

optimizador

El modelo del artículo se entrena con el optimizador AdamW (Loshchilov y Hutter, 2017) con los siguientes hiperparámetros:

Se usa un programa de tasa de aprendizaje de coseno tal que la tasa de aprendizaje final es igual al 10% de la tasa de aprendizaje máxima. El papel utiliza una caída de peso de 0,1 y un recorte de gradiente de 1,0. Use 2000 pasos de calentamiento y varíe la tasa de aprendizaje y el tamaño del lote con el tamaño del modelo.

Se utiliza la arquitectura basada en transformadores y se han realizado las siguientes 3 mejoras:

3.1 Prenormalización

Para mejorar la estabilidad del entrenamiento, la entrada de cada capa de transformador se normaliza en lugar de la salida.

Además, utilice la función de normalización RMS Norm. El nombre completo de RMS Norm es normalización de la capa Root Mean Square. En comparación con la norma de capa, la principal diferencia de la norma RMS es que se elimina la parte de restar el valor medio. La fórmula de cálculo es:

 El autor de RMS Norm cree que este modo simplifica el cálculo de Layer Norm y puede reducir el tiempo de cálculo entre un 7 % y un 64 % [3] .

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return (self.weight * hidden_states).to(input_dtype)

3.2 SwiGLU

SwiGLU se utiliza en lugar de ReLU como función de activación. A diferencia de PaLM, las dimensiones son 234� en lugar de 4�.

SwiGLU propuso en el artículo [4]  que, en comparación con otras variantes de la función de activación, se puede obtener el valor óptimo de perplejidad logarítmica (empatado con GEGLU).

class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        # config 中 hidden_act = 'silu'
        # 'silu' 和 'swish' 对应的激活函数均为:SiLUActivation 
        # https://github.com/huggingface/transformers/blob/717dadc6f36be9f50abc66adfd918f9b0e6e3502/src/transformers/activations.py#L229
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        # 对应上述公式的 SwiGLU
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

Se puede ver en el código que hay 3 capas lineales en LlamaMLP. La razón es que la función de activación de SwiGLU requiere una capa lineal más para la activación que la función de activación similar a ReLU.

3.3 CUERDA

La idea central de RoPE es "realizar la codificación de posición relativa por medio de la codificación de posición absoluta". Se puede decir que tiene la conveniencia de la codificación de posición absoluta y puede expresar la relación de posición relativa entre diferentes tokens. [5]  A diferencia del artículo original de Transformers, que agrega la incrustación pos y la incrustación de fichas, RoPE multiplica el código de posición y la consulta (o clave). detalles de la siguiente manera:

 

# 代码增加了注释,可以看到和原始公式的对应关系。
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # 此处 inv_freq 对应公式中的 theta
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # 此处 freqs 对应公式中的 m * theta, t 对应公式中的 m,表示位置
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 此处和原始公式不同,theta_0 和 theta_0 不再相邻
        # 而是分在向量的前半部分和后半部分
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # 大部分情况下,直接从这里返回
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # 此次和原始推导中不同,正负号不是间隔的,而是分前半部分和后半部分。但对于结果没有影响
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    # 对应上图中 RoPE 的简化计算
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

4. Implementación eficiente

Entrenamiento acelerado:

  • Utilice  la atención multicabezal causal  para mejorar la velocidad de entrenamiento del modelo. La implementación de este mecanismo toma prestada la biblioteca xformers , y su idea no es almacenar el peso de atención, ni calcular el puntaje de atención.
  • La función de activación de Transformer se implementó manualmente sin usar el autograd de la librería pytorch para obtener una mejor velocidad de entrenamiento. Al mismo tiempo, la tecnología de paralelización se utiliza para mejorar la velocidad de entrenamiento.
  • Se redujo la cantidad de cálculos para volver a calcular la activación durante el control de activación. La implementación manual de la función de transferencia inversa de la capa del transformador ahorra activaciones computacionalmente costosas, como la salida de la capa lineal.
  • Reduzca el uso de la memoria de video usando el paralelismo de modelos y el paralelismo de secuencias.
  • Paralelice el cálculo de activaciones y la comunicación entre GPU tanto como sea posible.

Efecto de aceleración:

  • El modelo 65B puede alcanzar una velocidad de 380 tokens/seg/GPU en GPU 2048 80G A100. Se necesitan 21 días para entrenar tokens de 1.4T.

5. Principales resultados y conclusiones

LLaMA-13B supera a GPT-3, aunque solo 1/10 del tamaño. LLaMA-65B es un modelo que puede competir con los mejores LLM como Chinchilla-70B y PaLM-540B. Después del ajuste fino, el efecto de LLaMA se ha mejorado significativamente.

En el futuro, planeamos lanzar un modelo más grande entrenado previamente en un corpus más grande, porque a medida que aumentan los datos y el modelo, podemos ver una mejora constante en el rendimiento.

5. Implementación eficiente

Dirección de código abierto de LLaMA2: https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md

Dirección de descarga de LLaMA2: https://ai.meta.com/resources/models-and-libraries/llama-downloads/

Dirección del blog oficial de LLaMA2: https://ai.meta.com/resources/models-and-libraries/llama/

5.1 Registro de detalles técnicos

  1. Tamaño del modelo : LLaMA 2 ofrece tres tamaños de modelo diferentes: 7B, 13B y 70B. Entre ellos, la arquitectura de 7B y 13B es la misma que LLaMA 1, que se puede utilizar directamente en aplicaciones comerciales.
  2. Entrenamiento : el modelo LLaMA 2 se entrena en 2 billones de tokens con el doble de longitud de contexto que LLaMA 1. Además, el modelo de chat LLaMA-2 se entrena en más de 1 millón de nuevas anotaciones humanas. LLaMA 2 tiene un 40% más de corpus de entrenamiento que LLaMA 1, y la longitud del contexto se ha incrementado de 2048 a 4096, lo que le permite comprender y generar textos más largos.
  3. Entrenamiento previo : LLaMA 2 se entrena previamente utilizando datos en línea disponibles públicamente, seguido de un ajuste fino supervisado para crear una versión inicial de LLaMA-2-chat. A continuación, LLaMA-2-chat se refina de forma iterativa mediante el aprendizaje de refuerzo con retroalimentación humana (RLHF), que incluye muestreo de rechazo y optimización de políticas proximales (PPO).
  4. Arquitectura del modelo : LLaMA 2 adopta la mayoría de las configuraciones previas al entrenamiento y la arquitectura del modelo de LLaMA 1, usando la arquitectura de transformador estándar, aplicando la normalización previa usando RMSNorm, usando la función de activación SwiGLU y la posición de rotación RoPE incrustada. Las principales diferencias arquitectónicas con respecto a LLaMA 1 incluyen una mayor longitud del contexto y atención de consultas agrupadas (GQA).
  5. Atención de consultas agrupadas (GQA) : este es un nuevo mecanismo de atención que mejora la escalabilidad de inferencia de modelos grandes. Funciona compartiendo proyecciones clave y de valor entre múltiples cabezas sin reducir drásticamente el rendimiento. Se puede utilizar el formato de consulta múltiple original (MQA) con una sola proyección de KV o la variante Atención de consultas agrupadas (GQA) con una proyección de 8KV.
  6. Hiperparámetros : entrenar con el optimizador AdamW, donde β1=0.9, β2=0.95, eps=10−5. Usando un programa de tasa de aprendizaje de coseno, caliente para 2000 pasos y disminuya la tasa de aprendizaje final al 10% de la tasa de aprendizaje máxima. Utilice el decaimiento de peso de 0.1 y el recorte de gradiente de 1.0.
  7. Tokenizador : LLaMA 2 usa el mismo tokenizador que LLaMA 1; usa el algoritmo Byte Pair Encoding (BPE), implementado usando SentencePiece. Al igual que LLaMA 1, divida todos los números en dígitos individuales y use bytes para descomponer caracteres UTF-8 desconocidos. El vocabulario total es de 32k fichas.
  8. Ajuste fino : LLaMA 2-Chat es el resultado de meses de investigación experimental y la aplicación iterativa de técnicas de alineación, incluido el ajuste fino de instrucciones y RLHF, que requieren recursos informáticos y de anotación de datos masivos. La calidad de los datos de las instrucciones de ajuste fino supervisadas es muy importante, incluida la diversidad, centrándose en la privacidad y la seguridad, y no contiene ningún dato de metausuario.
  9. Seguridad : este estudio evaluó la seguridad de Llama 2 utilizando tres puntos de referencia de uso común, centrándose en tres dimensiones clave: autenticidad, que se refiere a si el modelo de lenguaje generará información de error, usando el punto de referencia TruthfulQA; toxicidad, que se refiere a si el modelo de lenguaje generar contenido "tóxico ", grosero y dañino, use el punto de referencia ToxiGen; sesgo, que se refiere a si el modelo de lenguaje producirá contenido sesgado, use el punto de referencia BOLD.

5.2 Comparación de papeles llama1&2

  • El documento de llama1 detalla cómo funcionan los datos de entrenamiento y el lector puede replicar su proceso de entrenamiento.
  • El documento de llama2 se centra principalmente en la introducción de la metodología de entrenamiento de modelos, y la transparencia de la parte de los datos es baja.
  • llama1 consiste en un 67 % de datos de rastreadores públicos y un 15 % de datos seleccionados.
  • Con respecto a los datos de llama2, se proporciona menos información en el documento, pero el meta menciona que el corpus de preentrenamiento de llama2 es un 40 % más grande que el de llama1, que puede ser similar al conjunto de datos de llama1.

5.3 Importancia de la calidad de los datos

  • llama2 aumenta el peso de alta credibilidad en el conjunto de datos, lo que resulta en una mayor credibilidad del modelo;
  • Limpiar y seleccionar mejores datos es un paso fundamental para garantizar que los modelos generen texto preciso;
  • Con el avance de métodos como la evaluación de la calidad de los datos y la deduplicación, se ha generado debate sobre si la cantidad de datos de entrenamiento seguirá aumentando;
  • La relación entre la cantidad de datos de preentrenamiento y Chinchilla Optimal: durante el proceso de entrenamiento, la cantidad de datos de preentrenamiento de llama2 supera el límite establecido por Chinchilla Optimal, y la curva de pérdida sigue descendiendo al final del entrenamiento. Esto plantea la pregunta de si Chinchilla Optimal todavía se aplica.

referencia

Entrenamiento Compute-Optimal Modelos de lenguaje grande  https://arxiv.org/abs/2203.15556

Normalización de la capa cuadrática media raíz  https://arxiv.org/pdf/1910.07467.pdf

Las variantes de GLU mejoran el transformador  https://arxiv.org/pdf/2002.05202.pdf

El camino hacia la actualización de Transformer: 2. Codificación de posición rotatoria que aprende de las fortalezas de otros - Espacios científicos |

Llama2|El que obtiene los datos gana el mundo——Una maravillosa compilación de debates de LLaMA2 sobre Latent Space- Zhihu (zhihu.com)

LLM/Meta's LLAMA-2__ Introducción y traducción en papel - Saber (zhihu.com)

Detalles técnicos de LLaMA2 RLHF - Zhihu (zhihu.com)

Supongo que te gusta

Origin blog.csdn.net/zwqjoy/article/details/131943655
Recomendado
Clasificación