An in-depth analysis of how LLaMA improves the underlying structure of the Transformer

This article is shared from the Huawei Cloud Community article "How much do you know about the underlying architecture of large language models? An introduction to the model structure of LLMs", author: Ma Shanghua_Lancer.

Most current large language models adopt a GPT-like architecture: a decoder-only network built on the Transformer, trained as an autoregressive language model. They differ, however, in details such as position encoding, the placement of layer normalization, and the activation function. The previous article introduced the training process of GPT-3, including its model architecture, the composition of its training data, the training procedure, and the evaluation methods.

Since GPT-3 has not been open sourced, it is not easy to reproduce the entire training process directly from the paper. OPT (Open Pre-trained Transformer Language Models) was therefore built and open sourced based on the reproduction details described for GPT-3. Meta AI also open sourced the LLaMA model, which follows the GPT-3 style architecture; its public evaluation results, and the models obtained by supervised fine-tuning on top of it, have performed very well. Because OpenAI has not open sourced any model since GPT-3, it is not publicly known which model architectures ChatGPT and GPT-4 use.

This article takes the LLaMA model as an example to introduce how large language model architectures improve on the original Transformer structure, and then introduces optimization methods for the attention mechanism, which accounts for the largest share of both memory and computation time in the Transformer model.

1. Model structure of LLaMA

The previous article introduced the Transformer structure and the details used by LLaMA. The differences from the Transformer architecture introduced there are: the use of pre-normalization with the RMSNorm normalizing function, the replacement of the activation function with SwiGLU, and the use of Rotary Position Embedding (RoPE). The overall Transformer architecture is similar to that of GPT-2, as shown in Figure 1.1.

Figure 1.1 GPT-2 model structure

Next, the RMSNorm normalization function, the SwiGLU activation function, and rotary position embedding (RoPE) are introduced in turn, together with their implementations.

1.1. RMSNorm normalization function

To make the training process more stable, GPT-2 introduced pre-layer normalization compared with GPT: the first layer normalization is moved before the multi-head self-attention layer, the second layer normalization is moved before the fully connected layer, and the residual connections are placed after the multi-head self-attention layer and the fully connected layer, respectively. LLaMA additionally uses the RMSNorm normalization function for layer normalization. For an input vector a = (a_1, ..., a_n), RMSNorm is computed as:

RMS(a) = sqrt( (1/n) · Σ_{i=1..n} a_i² ),    ā_i = a_i / RMS(a)

In addition, RMSNorm can introduce a learnable scaling factor g_i and offset parameter b_i, giving:

ā_i = (a_i / RMS(a)) · g_i + b_i

The code implementation of RMSNorm in the HuggingFace Transformer library is as follows:

import torch
from torch import nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps # eps prevents the denominator from being 0 after taking the reciprocal
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # weight is the trainable parameter multiplied at the end, that is, g_i
        return (self.weight * hidden_states).to(input_dtype)
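
As a quick sanity check, the module above can be applied directly to a random hidden-state tensor; the batch size, sequence length and hidden size below are illustrative assumptions rather than values fixed by LLaMA.

import torch

hidden_size = 8
norm = LlamaRMSNorm(hidden_size, eps=1e-6)
x = torch.randn(2, 5, hidden_size)   # (batch, seq_len, hidden_size), toy shapes

y = norm(x)
# With the default weight of ones, each token vector is rescaled to unit root mean square
print(y.shape, y.pow(2).mean(-1).sqrt().mean().item())   # torch.Size([2, 5, 8]), ~1.0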

1.2. SwiGLU activation function

The SwiGLU [50] activation function was proposed by Shazeer and has been widely used in models such as PaLM, achieving good results: compared with ReLU, it improves performance in most evaluations. In LLaMA, the position-wise feed-forward network (FFN) with the SwiGLU activation function is computed as:

FFN_SwiGLU(x, W, V, W2) = SwiGLU(x, W, V) W2
SwiGLU(x, W, V) = Swish_β(xW) ⊗ xV
Swish_β(x) = x · σ(βx)

Here σ(x) is the sigmoid function. Figure 1.2 shows the shape of the Swish activation function for different values of the parameter β. As β approaches 0, the Swish function approaches the linear function y = x; as β approaches infinity, it approaches the ReLU function; when β = 1, the Swish function is smooth and non-monotonic. In the HuggingFace Transformer library, the Swish_1 function is implemented by the silu function.

Figure 1.2 The shape of the Swish activation function under different values ​​of parameter β
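
The following is a minimal sketch of a SwiGLU feed-forward block in PyTorch, written in the same spirit as LLaMA's MLP with a gate projection, an up projection and a down projection; the class name, layer names and sizes here are illustrative assumptions, not the exact HuggingFace implementation.

import torch
from torch import nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """FFN_SwiGLU(x) = (Swish_1(x W_gate) ⊗ (x W_up)) W_down, without biases."""
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # F.silu implements Swish_1(x) = x * sigmoid(x)
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

ffn = SwiGLUFFN(hidden_size=512, intermediate_size=1376)
out = ffn(torch.randn(2, 16, 512))   # -> (2, 16, 512)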

1.3. Rotary Position Embedding (RoPE)

For position encoding, Rotary Position Embedding (RoPE) [52] is used instead of the original absolute position encoding. RoPE borrows ideas from complex numbers: its starting point is to realize relative position encoding by means of absolute position encoding. The goal is to add absolute position information to the query q and key k through functions of the form:

q̃_m = f(q, m),    k̃_n = f(k, n)

so that the inner product ⟨q̃_m, k̃_n⟩ depends only on q, k and the relative position m − n.

After these operations, q̃_m and k̃_n carry the absolute position information of positions m and n.

Finally, in the two-dimensional case, RoPE can be expressed with complex numbers as:

f(q, m) = (q_0 + i·q_1) · e^{i m θ}

By the geometric meaning of complex multiplication, this transformation is simply a rotation of the corresponding vector, hence the name "rotary position embedding". It can also be written in matrix form:

f(q, m) = ( cos mθ   −sin mθ )  ( q_0 )
          ( sin mθ    cos mθ )  ( q_1 )

Since the inner product satisfies linear superposition, RoPE in any even dimension d can be expressed as a concatenation of two-dimensional cases: f(q, m) = R_m q, where R_m is a block-diagonal matrix composed of d/2 two-dimensional rotation blocks, the i-th block rotating by the angle mθ_i with θ_i = 10000^{−2(i−1)/d}.

Because the matrix R_m is sparse, the computation can be further accelerated with element-wise multiplication ⊗ instead of a full matrix product. The code implementation of RoPE in the HuggingFace Transformer library is as follows:

import torch

class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build the cos/sin cache here so that `torch.jit.trace` works.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
                         dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from the paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`.
        # Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from the paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
                                 persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
                                 persistent=False)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
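
A short usage sketch of the rotary embedding above follows; the tensor shapes (batch 1, 8 heads, sequence length 16, head size 64) are illustrative assumptions.

import torch

bs, n_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=2048)
cos, sin = rope(q, seq_len=seq_len)                 # each: (1, 1, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)   # (1, seq_len)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# The inner product of q_rot and k_rot now depends only on the relative distance m - n
print(q_rot.shape, k_rot.shape)   # torch.Size([1, 8, 16, 64]) for both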

1.4. Overall model framework

Based on the components above, the decoder layer can be implemented, and language modeling on the training corpus proceeds autoregressively in essentially the same way as described earlier. The specific hyperparameters used by LLaMA models of different scales are shown in Table 1.3. However, because a large language model has an enormous number of parameters and requires a large amount of training data, training cannot be completed on a single GPU and requires a distributed model training framework (covered in detail in a later article).

Table 1.3 Specific hyperparameters of LLaMA at different model sizes

  Parameters   Hidden dim   Heads   Layers   Learning rate   Batch size   Tokens
  6.7B         4096         32      32       3.0e-4          4M           1.0T
  13.0B        5120         40      40       3.0e-4          4M           1.0T
  32.5B        6656         52      60       1.5e-4          4M           1.4T
  65.2B        8192         64      80       1.5e-4          4M           1.4T

The overall implementation code of the LLaMA decoder in the HuggingFace Transformer library is as follows:

from typing import Optional, Tuple

import torch
from torch import nn

# LlamaConfig, LlamaAttention, LlamaMLP and LlamaRMSNorm are defined in the same module

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
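
To see how such decoder layers are assembled into a complete model, the sketch below builds a very small random-weight LLaMA through the public HuggingFace interface, assuming the transformers library with LLaMA support is installed; all configuration values are toy numbers chosen for illustration.

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Toy configuration: 4 decoder layers, hidden size 256, 4 attention heads
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=688,
    num_hidden_layers=4,
    num_attention_heads=4,
    rms_norm_eps=1e-6,
)
model = LlamaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(input_ids).logits   # (1, 16, vocab_size)
print(logits.shape)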

2. Optimization of attention mechanism

In the Transformer structure, the time and memory complexity of the self-attention mechanism grows quadratically with the sequence length, so it occupies a large amount of accelerator memory and consumes substantial compute. How to optimize the space-time complexity of self-attention and improve computational efficiency is therefore an important problem that large language models must face. Some studies start from approximate attention and aim to reduce the computation and memory requirements of attention, proposing methods such as sparse approximation and low-rank approximation. Other studies start from the characteristics of the compute hardware itself and investigate how to exploit hardware features to compute the attention layer in the Transformer more efficiently. This article introduces both types of methods.
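
To make the quadratic cost concrete, the snippet below estimates the memory needed just to store the attention score matrix of a single head in float16 at several sequence lengths; it is a back-of-the-envelope illustration, not a measurement of any particular model.

# Memory for one attention score matrix S of shape (seq_len, seq_len) in float16
for seq_len in [1024, 4096, 16384, 65536]:
    bytes_needed = seq_len * seq_len * 2   # 2 bytes per float16 element
    print(f"seq_len={seq_len:6d}: {bytes_needed / 2**20:9.1f} MiB per attention head")
# Quadrupling the sequence length multiplies the memory by 16.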

2.1. Sparse attention mechanism

Analysis of the attention matrices of trained Transformer models shows that many of them are sparse, so the computational complexity can be reduced by limiting the number of Query-Key pairs. This class of methods is called sparse attention (Sparse Attention). Sparsification methods can be further divided into two categories: position-based and content-based. The basic types of position-based sparse attention are shown in Figure 2.1 and mainly include the following five patterns:

(1) Global Attention: In order to enhance the model's ability to model long-distance dependencies, some global nodes can be added;

(2) Band Attention: Most data is localized, limiting Query to only interact with a few adjacent nodes;

(3) Dilated Attention: similar to dilated convolutions in CNNs, a larger receptive field is obtained by introducing gaps in the band;

(4) Random Attention: Improve non-local interaction through random sampling;

(5) Block Local Attention: Use multiple non-overlapping blocks to limit information interaction.

Figure 2.1 Five basic types of location-based sparse attention
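
A minimal sketch of how the band and global patterns translate into attention masks is shown below; the window width and the choice of global positions are illustrative assumptions.

import torch

def band_and_global_mask(seq_len: int, window: int = 3, global_positions=(0,)):
    """Boolean mask where True means the Query may attend to the Key."""
    idx = torch.arange(seq_len)
    # Band attention: each position attends only to neighbours within the window
    mask = (idx[:, None] - idx[None, :]).abs() <= window
    mask = mask.clone()
    # Global attention: chosen positions attend to everything and are attended by everything
    for g in global_positions:
        mask[g, :] = True
        mask[:, g] = True
    return mask

print(band_and_global_mask(seq_len=8, window=1, global_positions=(0,)).int())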

Existing sparse attention mechanisms are usually composites of the five basic position-based patterns above. Figure 2.2 shows some typical sparse attention models.

Star-Transformer [54] uses a combination of band attention and global attention. Specifically, Star-Transformer contains only a single global attention node and band attention with a width of 3: any two non-adjacent nodes are connected through the shared global attention node, while adjacent nodes are connected directly.

Longformer uses a combination of band attention and internal global-node attention. In addition, Longformer replaces some of the band attention heads in the upper layers with dilated-window attention, which enlarges the receptive field without increasing computation. The Extended Transformer Construction (ETC) uses a combination of band attention and external global-node attention. ETC's sparse attention also includes a masking mechanism to handle structured inputs, and uses Contrastive Predictive Coding (CPC) for pre-training.

BigBird uses band and global attention, plus additional random attention to approximate fully connected attention. The authors also show that using sparse encoders and sparse decoders can simulate any Turing machine, which to some extent explains why sparse attention models can achieve good results.

Figure 2.2 Position-based composite sparse attention type

Content-based sparse attention constructs the sparsity pattern from the input data itself. One simple method is to select the keys (Key) that are most similar to a given query (Query). Routing Transformer applies k-means clustering jointly to the queries and keys, producing a set of k cluster-center vectors {μ_1, ..., μ_k}; each Query then only interacts with the Keys assigned to the same cluster. The cluster centers are updated with an exponential moving average of the queries and keys assigned to each cluster, normalized by |μ|, the number of vectors currently assigned to cluster μ.

Reformer [60] uses locality-sensitive hashing (LSH) to select the Key-Value pairs for each Query. The main idea is to hash Queries and Keys into multiple buckets with an LSH function, so that Queries and Keys falling into the same bucket have a high probability of interacting. Let b be the number of buckets; given a random matrix R of size [D_k, b/2], the LSH function is defined as:

h(x) = arg max([xR; −xR])

If h(q_i) = h(k_j), then q_i interacts with the corresponding Key-Value pair.
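
The bucketing step can be sketched with a single random projection following the definition h(x) = arg max([xR; −xR]) above; the dimensions and the number of buckets are illustrative assumptions, and the multi-round hashing, sorting and chunking of the full Reformer are omitted.

import torch

def lsh_bucket(x: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
    """h(x) = argmax([xR ; -xR]); x: (n, d_k), R: (d_k, b/2) -> bucket ids in [0, b)."""
    proj = x @ R                                    # (n, b/2)
    return torch.cat([proj, -proj], dim=-1).argmax(dim=-1)

d_k, b, n = 64, 8, 16
R = torch.randn(d_k, b // 2)                        # shared random projection
queries, keys = torch.randn(n, d_k), torch.randn(n, d_k)

q_buckets, k_buckets = lsh_bucket(queries, R), lsh_bucket(keys, R)
# A Query only interacts with Keys that fall into the same bucket
same_bucket = q_buckets[:, None] == k_buckets[None, :]   # (n, n) boolean mask
print(same_bucket.sum().item(), "candidate Query-Key pairs instead of", n * n)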

2.2. FlashAttention

Whether GPU memory sits physically on the GPU chip or in the on-board RAM chips determines its speed, capacity, and access restrictions. NVIDIA GPU memory is divided into six major categories: global memory, local memory, shared memory (SRAM), register memory, constant memory, and texture memory. Figure 2.3 shows the overall memory structure of an NVIDIA GPU. Among them, global memory, local memory, shared memory, and registers are readable and writable.

Global memory and local memory use High Bandwidth Memory (HBM) located in the on-board RAM chips, and this part of memory has a large capacity. Global memory can be accessed by all threads, while local memory can only be accessed by the current thread. The global memory of an NVIDIA H100 is 80GB, and although its bandwidth can reach 3.35TB/s, the average bandwidth per thread is still low when all threads access global memory simultaneously. Shared memory and registers are located on the GPU chip, so their capacity is very small; shared memory can only be shared by threads within the same thread block (Thread Block), and registers can only be accessed within a single thread.

In an NVIDIA H100, each thread block can use only 228KB of shared memory per Streaming Multiprocessor (SM), but this memory is very fast, far faster than global memory access.

Figure 2.3 Overall memory structure of an NVIDIA GPU

The principle of the self-attention mechanism was introduced earlier. When computed on a GPU, the standard implementation also introduces two intermediate matrices, S and P, which are stored in global memory. The computation proceeds as:

S = QK^T,    P = Softmax(S),    O = PV

With this procedure, the matrices Q and K are first read from global memory and the resulting matrix S is written back to global memory; S is then read again to compute the softmax and obtain P, which is written back to global memory; finally P and V are read to compute the output matrix O. This process consumes a large amount of memory bandwidth. Since computation is much faster than memory access in self-attention, efficiency is increasingly bottlenecked by global memory access.
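
The snippet below spells out this standard computation in PyTorch, materializing S and P explicitly; on a GPU these two intermediates would live in global memory, which is exactly the traffic FlashAttention tries to avoid. The shapes are illustrative assumptions.

import math
import torch

seq_len, d = 1024, 64
Q = torch.randn(seq_len, d)
K = torch.randn(seq_len, d)
V = torch.randn(seq_len, d)

S = Q @ K.T / math.sqrt(d)      # (seq_len, seq_len), written to and read back from HBM
P = torch.softmax(S, dim=-1)    # (seq_len, seq_len), written to and read back from HBM
O = P @ V                       # (seq_len, d)
print(S.shape, P.shape, O.shape)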

FlashAttention exploits the difference in I/O speed between global memory (HBM) and on-chip shared memory in GPU hardware to avoid reading and writing the attention matrix to and from HBM as much as possible.

The goal of FlashAttention is to use SRAM as efficiently as possible to speed up computation, avoiding reads and writes of the attention matrix from global memory. Achieving this requires computing the softmax function without access to the entire input at once, and not storing the intermediate attention matrix for backpropagation.

In the standard attention algorithm, the softmax is computed row by row: before multiplying with V, each row of QK^T must be computed in full so that the softmax result for that row is available, and only then is the block-wise matrix multiplication with V performed. In FlashAttention, the softmax is instead computed incrementally, by splitting the input into blocks and making multiple passes over them.

The standard implementation of self-attention writes the intermediate matrices S and P to global memory, and the size of these matrices grows quadratically with the input sequence length. FlashAttention therefore avoids materializing the intermediate attention matrix and reduces global memory consumption by storing only the softmax normalization factors.

Instead of writing S and P to global memory as a whole, the FlashAttention algorithm processes them block by block, stores the softmax normalization factors from the forward pass, and quickly recomputes the attention on-chip during backpropagation; this is faster than the standard approach of reading the intermediate attention matrix back from global memory. Because global memory accesses are greatly reduced, it runs faster and uses less memory, even though the recomputation increases FLOPs. The computation corresponding to the inner and outer loops is illustrated in Figure 2.4 below.

Figure 2.4 FlashAttention calculation flow chart
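
The sketch below illustrates the core idea of blockwise computation with an online softmax in plain PyTorch; it mirrors the tiling logic of FlashAttention but is only a readability-oriented reference (single head, no causal mask), not the fused CUDA kernel, and the block size is an arbitrary assumption.

import math
import torch

def blockwise_attention(Q, K, V, block_size=128):
    """Compute softmax(QK^T / sqrt(d)) V one Key/Value block at a time,
    keeping only O(seq_len * d) state instead of the full attention matrix."""
    seq_len, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.zeros_like(Q)                            # running (unnormalized) output
    row_max = torch.full((seq_len, 1), float("-inf"))    # running max per Query row
    row_sum = torch.zeros(seq_len, 1)                    # running softmax denominator

    for start in range(0, seq_len, block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        S = Q @ Kb.T * scale                             # scores for this block only

        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)        # rescale previous accumulators
        P = torch.exp(S - new_max)

        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ Vb
        row_max = new_max

    return out / row_sum

Q, K, V = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T / math.sqrt(64), dim=-1) @ V
print(torch.allclose(blockwise_attention(Q, K, V), ref, atol=1e-4))   # True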

2.3. Multi-query attention

Multi-query attention (Multi Query Attention) [62] is a variant of multi-head attention. Its main characteristic is that the different attention heads share a single set of keys and values, and each head keeps only its own copy of the query parameters.

Therefore, only a single set of key and value projections (and a single key-value cache) is needed, which greatly reduces memory usage and improves efficiency. Since multi-query attention changes the structure of the attention mechanism, models usually need to support it from the start of training. The results in [63] show that multi-query attention can instead be added by fine-tuning an already trained model, requiring only about 5% of the original training volume to reach good performance. Many models, including Falcon, SantaCoder, and StarCoder, use the multi-query attention mechanism.

Taking LLM Foundry as an example, the multi-query attention implementation code is as follows:


from typing import Optional

import torch
from torch import nn

# scaled_multihead_dot_product_attention is provided elsewhere in LLM Foundry

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(  # Multi-Query Attention projection
            d_model,
            d_model + 2 * self.head_dim,  # per-head vectors only for the query (d_model);
            device=device,                # key and value share a single head_dim each
        )
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(
            self.d_model,
            self.d_model,
            device=device
        )
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
    ):
        qkv = self.Wqkv(x)  # (1, 512, 960)
        query, key, value = qkv.split(  # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],  # key   -> (1, 512, 96)
            dim=2,                                         # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            multiquery=True,
        )
        return self.out_proj(context), attn_weights, past_key_value

Compared with the multi-head self-attention implementation in LLM Foundry, the difference lies only in how the Wqkv layer is constructed and how its output is split:

# Multi-Head Attention
self.Wqkv = nn.Linear(  # Multi-Head Attention projection
    self.d_model,
    3 * self.d_model,   # query, key and value each need d_model, hence 3 * d_model
    device=device
)
query, key, value = qkv.chunk(  # each tensor is (1, 512, 768)
    3,
    dim=2
)

# Multi-Query Attention
self.Wqkv = nn.Linear(  # Multi-Query Attention projection
    d_model,
    d_model + 2 * self.head_dim,  # only the query keeps per-head vectors (d_model);
    device=device,                # key and value each get a single head_dim
)
query, key, value = qkv.split(  # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],  # key   -> (1, 512, 96)
    dim=2                                          # value -> (1, 512, 96)
)
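
A quick comparison makes the saving concrete; the values d_model = 768 and 12 heads below are purely illustrative assumptions.

d_model, n_heads = 768, 12
head_dim = d_model // n_heads

# Output width of the combined Wqkv projection
mha_qkv_width = 3 * d_model              # 2304: Q, K and V each carry all 12 heads
mqa_qkv_width = d_model + 2 * head_dim   # 896:  K and V keep a single head each
print(mha_qkv_width, mqa_qkv_width)

# Key/value cache entries per token during autoregressive decoding
print("KV cache per token:", 2 * d_model, "->", 2 * head_dim)   # 1536 -> 128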

This article has taken the LLaMA model as an example to introduce, from the bottom up, how large language model architectures improve on the original Transformer structure, and has introduced optimization methods for the attention mechanism, which accounts for the largest share of memory and computation time in the Transformer. The material may seem a bit dry, but only by understanding the principles of large models at this level can we use them better.

