Llama explains in simple terms

Llama explains in simple terms

The following article comes from Algorithm Gourmet House, author Liang Yun 1991

Algorithmic Gourmet House.

Once upon a time, there was a persistent foodie who turned complicated algorithms into gourmet food.

Heads-up, solid content ahead: this may be the easiest and most practical tutorial you can find for learning the source code of open-source LLMs.

This example builds the Llama model from scratch on top of the transformers library and explains its source code module by module. (The Chinese name for "Llama" is often translated as "alpaca".)

We then train it on a fun toy task: adding two numbers.

The input and output are similar to the following:

Input: "12345+54321="

Output: "66666"

We treat this task as text generation: the first part of each sequence is the input, and the second part is the output.

This is similar to the input and output structure of text generation, so it can be done with Llama.

At present, most open source LLM models are based on the transformers library, and most of their structures are similar to Llama.

As the saying goes, the devil is in the details. A deep understanding of the source code will help you grasp the underlying principles of open-source LLMs (such as rotary position encoding and length extrapolation). It will also make you familiar with configuring and using their various parameters (such as past_key_value and attention_mask).

1. Prepare data

import random

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader

# Vocabulary: token -> id, with ids assigned in the order the tokens are listed.
words = '<PAD>,<BOS>,<EOS>,1,2,3,4,5,6,7,8,9,0,+,='
vocab = {token: token_id for token_id, token in enumerate(words.split(','))}
# Reverse lookup id -> token; dicts preserve insertion order, so position == id.
vocab_r = list(vocab)

# Two-number addition dataset
def get_data(min_length=10, max_length=20):
    """Sample one random addition example.

    Each operand is a list of digit characters whose length is drawn with
    ``random.randint(min_length, max_length)`` (inclusive bounds). Digits are
    drawn independently with fixed non-uniform weights, so an operand may start
    with '0'.

    Returns:
        (x, y) where x is ['<BOS>', *lhs, '+', *rhs, '='] and y is the decimal
        digits of int(lhs) + int(rhs) followed by '<EOS>'.
    """
    digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # Relative sampling weight of each digit, normalised to probabilities.
    weights = np.array([7, 5, 5, 7, 6, 5, 7, 6, 5, 7])
    probs = weights / weights.sum()

    def _sample_operand():
        # Draw the length first, then the digits (same RNG call order per operand).
        length = random.randint(min_length, max_length)
        return np.random.choice(digits, size=length, replace=True, p=probs).tolist()

    lhs = _sample_operand()
    rhs = _sample_operand()

    # Context: the operands joined character-wise; target: their numeric sum.
    total = int(''.join(lhs)) + int(''.join(rhs))
    x = ['<BOS>', *lhs, '+', *rhs, '=']
    y = [*str(total), '<EOS>']

    return x, y

# Draw one random example and print context and target as one string.
x, y = get_data()
print(''.join(x + y), "\n")


<BOS>3914835626735057733+318829464988=3914835945564522721<EOS>
# Dataset definition
class TwoSumDataset(torch.utils.data.Dataset):
    """Map-style dataset of random two-number addition examples.

    The index passed to ``__getitem__`` is ignored: every access draws a fresh
    random example via ``get_data``. ``size`` only sets the length reported by
    ``__len__``.
    """

    def __init__(self, size=100000, min_length=10, max_length=20):
        # Zero-argument super() continues the MRO right after TwoSumDataset, so
        # Dataset's own initializer (if it defines one) is not skipped.
        super().__init__()
        self.size = size
        self.min_length = min_length
        self.max_length = max_length

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        x, y = self.get(i)

        # Encode tokens to ids (the loop variable no longer shadows `i`).
        context_ids = [vocab[tok] for tok in x]
        target_ids = [vocab[tok] for tok in y]

        input_ids = context_ids + target_ids

        # Label -100 is ignored when computing the loss, so only the target
        # (answer) part contributes to the loss we optimize.
        labels = [-100] * len(context_ids) + target_ids
        # get_data never emits <PAD>, so this mask is all ones here; padding and
        # the matching zeros are added later when batching.
        masks = [0 if t == vocab['<PAD>'] else 1 for t in input_ids]

        example = {'input_ids': input_ids,
                   'labels': labels, 'attention_mask': masks}

        return example

    def get(self, i):
        """Return a fresh random (x, y) token pair; `i` is unused."""
        return get_data(self.min_length, self.max_length)

    def show_example(self, example):
        """Print an encoded example as text: context tokens (label -100) then target tokens."""
        input_ids, labels = example['input_ids'], example['labels']
        x = ''.join([vocab_r[a] for a, b in zip(input_ids, labels) if b == -100])
        y = ''.join([vocab_r[a] for a, b in zip(input_ids, labels) if b != -100])
        print(x + y)
# Training / validation sets; each access draws a fresh random example.
ds_train = TwoSumDataset(size=100000, min_length=10, max_length=20)
ds_val = TwoSumDataset(size=10000, min_length=10, max_length=20)

# Decode one encoded example back to text as a sanity check.
example = ds_train[0]
ds_train.show_example(example)

<BOS>12878683929048906366+11274414130675477=12889958343179581843<EOS>
def data_collator(examples: list, pad_token_id=None):
    """Left-pad a list of examples into a batch of equal-length LongTensors.

    Examples are reordered longest-first (a stable sort, so equal-length
    examples keep their relative order). Shorter sequences are padded on the
    LEFT up to the longest length in the batch: input_ids with
    ``pad_token_id``, labels with -100 (ignored by the loss), and
    attention_mask with 0.

    Args:
        examples: dicts with list-valued 'input_ids', 'labels' and
            'attention_mask' of equal length per example.
        pad_token_id: id used to pad input_ids; defaults to vocab['<PAD>'].

    Returns:
        dict with 'input_ids', 'labels', 'attention_mask', each of shape
        (len(examples), longest).

    Raises:
        ValueError: if ``examples`` is empty.
    """
    if not examples:
        raise ValueError("data_collator received an empty list of examples")
    if pad_token_id is None:
        pad_token_id = vocab['<PAD>']

    # Pad to the longest input_ids in this batch.
    longest = max(len(example["input_ids"]) for example in examples)
    # reverse=True keeps sort stability, matching a key of -length.
    ordered = sorted(examples, key=lambda example: len(example["input_ids"]), reverse=True)

    input_ids, labels_list, masks_list = [], [], []
    for example in ordered:
        n_pad = longest - len(example["input_ids"])
        # Left padding keeps the last real token at the end of every row.
        input_ids.append(torch.LongTensor([pad_token_id] * n_pad + example["input_ids"]))
        labels_list.append(torch.LongTensor([-100] * n_pad + example["labels"]))
        masks_list.append(torch.LongTensor([0] * n_pad + example["attention_mask"]))

    return {
        "input_ids": torch.stack(input_ids),
        "labels": torch.stack(labels_list),
        "attention_mask": torch.stack(masks_list),
    }

# Data loaders. Both use batches of 200, drop the last incomplete batch and
# pad with data_collator; only the training loader shuffles indices.
dl_train = DataLoader(
    dataset=ds_train,
    batch_size=200,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,
)

dl_val = DataLoader(
    dataset=ds_val,
    batch_size=200,
    shuffle=False,
    drop_last=True,
    collate_fn=data_collator,
)

# Peek at the first training batch.
for batch in dl_train:
    break
batch
{'input_ids': tensor([[ 1, 11,  6,  ...,  7, 11,  2],
         [ 0,  1,  6,  ...,  5,  4,  2],
         [ 0,  1,  7,  ...,  8,  8,  2],
         ...,
         [ 0,  0,  0,  ..., 10, 11,  2],
         [ 0,  0,  0,  ..., 12,  3,  2],
         [ 0,  0,  0,  ..., 11, 12,  2]]),
 'labels': tensor([[-100, -100, -100,  ...,    7,   11,    2],
         [-100, -100, -100,  ...,    5,    4,    2],
         [-100, -100, -100,  ...,    8,    8,    2],
         ...,
         [-100, -100, -100,  ...,   10,   11,    2],
         [-100, -100, -100,  ...,   12,    3,    2],
         [-100, -100, -100,  ...,   11,   12,    2]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]])}

2. Define the model

Next, we will build the LLaMA model from low to high like building blocks to build a castle.

First, we build 4 basic components: rotary position encoding, multi-head attention, the feedforward network, and layer normalization. This is like building the walls, roofs, doors, and windows out of the most basic bricks.

Then use these 4 basic components to build intermediate products: the decoding layer. It's like building a room out of basic components.

Then, a complete model of LlamaModel is assembled by stacking multiple intermediate decoding layers, which is equivalent to building the main structure of the castle by building multiple rooms.

Finally, we put two different output heads on top of LlamaModel. The first is a language-model head, which gives LlamaForCausalLM for text generation.

The second is a classification head, which gives LlamaForSequenceClassification for text classification.

This is like decorating the finished castle in two different styles: one adds recreational facilities for commercial activities, and the other adds weapons for military activities.


1, Rotary position encoding: RoPE (absolute position encoding implemented with a rotation matrix, which achieves the effect of relative position encoding)

2, Multi-head attention: LlamaAttention (used to fuse information between different tokens )

3, Feedforward network: LlamaMLP (applies a position-wise, high-dimensional transformation to the information fused by multi-head attention)

4, Layer normalization: LlamaRMSNorm (used to stabilize the input, which is equivalent to keeping the direction of each word vector unchanged, but standardizing the modulus length. )


5, Llama decoding layer: LlamaDecoderLayer ( basic structural unit with information fusion and information conversion functions at the same time )


6, Llama decoder: LlamaModel ( stacking of multiple decoding layers )


7. Llama language model: LlamaForCausalLM (decoder plus language model head, which can be used for text generation)

8. Llama classification model: LlamaForSequenceClassification (decoder plus classification head, can be used for text classification)


import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

from transformers.models.llama.configuration_llama  import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING

# `logging` here is the transformers.utils logging helper imported above.
logger = logging.get_logger('llama')

# A deliberately small Llama configuration for the two-number addition task.
config = LlamaConfig(
    vocab_size=len(vocab),          # 15 tokens: <PAD>, <BOS>, <EOS>, ten digits, '+', '='
    hidden_size=512,
    intermediate_size=2752,         # inner width of the feed-forward (MLP) layer
    num_hidden_layers=8,
    num_attention_heads=16,         # head_dim = 512 // 16 = 32
    hidden_act='silu',
    # With the default lengths (10-20 digits) the longest example is
    # 1 + 20 + 1 + 20 + 1 + 21 + 1 = 65 tokens, well under this limit.
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,                 # matches vocab['<PAD>']
    bos_token_id=1,                 # matches vocab['<BOS>']
    eos_token_id=2,                 # matches vocab['<EOS>']
    tie_word_embeddings=False
) 

1. Rotary position encoding RoPE

Rotary position encoding uses a rotation matrix to represent the position encoding (Rotary Position Encoding), referred to as RoPE.

The three core points of knowledge about RoPE are as follows:

  • The design idea of RoPE is to use absolute position encoding to achieve the effect of relative position encoding.

  • RoPE is implemented by using a rotation matrix to represent absolute position encoding.

  • Using the NTK extension method allows RoPE to train on short text and make predictions on long text.

Reference articles:

"Rotary Position Embedding: Drawing on the Strengths of Others" https://kexue.fm/archives/8265

"RoPE Is a β-Base Encoding" https://kexue.fm/archives/9675

(1) Absolute position coding and relative position coding

Position coding can generally be divided into absolute position coding and relative position coding.

The advantage of absolute position encoding is that the calculation is simple and efficient, but the disadvantage is that the general effect is not as good as that of relative position encoding.

The advantage of relative position coding is that the effect is better, but the disadvantage is that the calculation efficiency is not as good as absolute position coding.

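Absolute position encoding: each token vector is tagged with its own position independently, e.g. $\tilde{q}_m = f(q, m)$ and $\tilde{k}_n = f(k, n)$, and the attention score is $\langle f(q, m), f(k, n)\rangle$.

Relative position encoding: the attention score depends on the two positions only through their difference, $\langle f(q, m), f(k, n)\rangle = g(q, k, m - n)$.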

In relative position encoding, the result of the attention weight is only related to the relative position of the token vector involved in the attention calculation, not directly related to the absolute position.

This matches the fact that natural language is roughly translation-invariant along the sequence dimension, which is why relative position encoding usually works better than absolute position encoding.

However, absolute position coding is not useless. Absolute position coding only needs to assign a position code to each position of the sequence (the number is proportional to the sequence length) during initialization, and no subsequent intervention is required.

However, the relative position encoding needs to obtain many relative positions (the number is proportional to the square of the sequence length) in the calculation process.

Therefore, absolute position coding is simpler and more efficient.

(2) Use rotation matrix to represent position encoding

It can be seen from the above discussion that absolute position coding and relative position coding have advantages and disadvantages, so is there any way to learn from each other?

Yes, this method is RoPE, and its design idea is to use absolute position encoding to achieve the effect of relative position encoding.

So how does the rotary position encoding use the absolute position encoding to achieve the effect of the relative position encoding? The answer is to use a rotation matrix to represent the positional encoding.

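Let $f(q, m) = R_m q$ and $f(k, n) = R_n k$, where $R_m$ is a rotation matrix satisfying $R_m^\top R_n = R_{n-m}$. Then:

$$\langle f(q, m), f(k, n)\rangle = (R_m q)^\top (R_n k) = q^\top R_m^\top R_n k = q^\top R_{n-m} k,$$

which depends only on $n - m$ and therefore has the relative-position form.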

perfect! We use absolute position encoding to achieve the effect of relative position encoding.

So, what does the rotation matrix look like?

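In two dimensions it looks like this:

$$R_m = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}$$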

In the NLP field, the dimensionality of word vectors is generally very high (for example, 4096).

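Using block matrices, one can show that in $d$ dimensions the following block-diagonal expansion still has the rotation-matrix property:

$$R_m = \begin{pmatrix} \cos m\theta_0 & -\sin m\theta_0 & & & \\ \sin m\theta_0 & \cos m\theta_0 & & & \\ & & \ddots & & \\ & & & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\ & & & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \end{pmatrix}$$

Here $\theta_i = 10000^{-2i/d}$ for $i = 0, 1, \dots, d/2 - 1$. So the higher the dimension index $i$, the smaller the frequency coefficient of the trigonometric functions, the longer the period, and the slower the change.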

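Since the rotation matrix is sparse, computing $R_m q$ with a full matrix multiplication would waste compute. It can instead be written as the sum of two element-wise (Hadamard) products:

$$R_m q = q \odot \cos(m\Theta) + \tilde{q} \odot \sin(m\Theta),$$

Here $\cos(m\Theta) = (\cos m\theta_0, \cos m\theta_0, \cos m\theta_1, \cos m\theta_1, \dots)$, $\sin(m\Theta)$ is defined likewise, and $\tilde{q} = (-q_1, q_0, -q_3, q_2, \dots)$.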

(3) Length extension of rotary position encoding

In the application of LLM, there is a very important parameter called the context length supported by LLM (max context length).

A longer context length allows us to have more rounds of dialogue, allows us to conduct summary analysis of longer papers, and also allows us to generate longer articles.

But when training an LLM, most of the training corpus is not long enough. Many LLMs are trained with a maximum context of only 2k, i.e. 2048 tokens.

So, can you use shorter texts during training and extend to long texts during inference?

It is possible that we can extend the length of RoPE.

We introduce 3 extension scenarios.

The first is direct extrapolation: direct extrapolation is actually continuing to use the existing position encoding formula without any modification.

When the extension length is not too long, such as extending from 2k to 2.5k , this method may have little impact on performance.

This is because rotary position encoding depends only on the relative position $m - n$. It generally exhibits long-range decay: the correlation between two tokens that are farther apart tends to be weaker.

Therefore, if our model has learned a suitable attenuation law of the correlation between tokens relative to the relative distance between 0-2k from the training data, it is conceivable that applying this law to 0-2.5k is not too big of the problem.

But if we want to extend to a longer length, such as extending from 2k to 32k, this direct extrapolation scheme usually seriously affects performance. Because the attenuation law we learned may be completely attenuated and truncated to 0 at 5k, so we cannot capture the interaction between two tokens whose relative distance is longer than 5k, and extrapolation will lead to performance degradation.

To summarize, the use of direct extrapolation to the attenuation law over long distances is prone to problems, leading to performance degradation.

In order to reduce the impact of length extrapolation on performance, we can fine-tune the trained model in a few steps on longer contexts.

The second is linear interpolation: linear interpolation needs to change the position encoding formula, which is equivalent to reducing the position serial number proportionally.

The angle changes from $m\theta_i$ to $\frac{m}{k}\theta_i$, where $k$ is the extension factor. Extending from 2k to 32k means $k = 16$, i.e. every position index is shrunk to 1/16 of its original value.

Linear interpolation does not change the application range of the attenuation law learned by the model, and its effect is generally better than the direct extrapolation scheme without considering fine-tuning.

However, when the expansion factor is very large, such as extending from 2k to 32k, its performance will be significantly affected.

In this case, the decay law is badly distorted at short distances. Two tokens 1 position apart look only 1/16 of a position apart after the extension. Because the decay law can change very quickly at short distances, the estimated correlation can deviate far from reasonable values.

Fine-tuning a few steps over long text can also improve performance significantly when linear interpolation is applied.

The third is the NTK extension method: this method combines the advantages of extrapolation and interpolation, and can maintain good performance even without fine-tuning after length extension.

From the previous analysis, we know that the use of direct extrapolation to the attenuation law in the long-distance situation is prone to problems, and the use in the short-distance situation is not affected.

However, linear interpolation is prone to problems in the use of the attenuation law in the case of short distances, and has little influence in the case of long distances.

Can we combine the two? We want extrapolation-like behavior at short distances (essentially unchanged from before the extension) and interpolation-like behavior at long distances (scaled back into the pre-extension range). Then the decay law is barely disturbed at either short or long distances.

Look at the first entry of the RoPE matrix, $\cos m\theta_i$ with $\theta_i = 10000^{-2i/d}$. The larger $i$ is, the smaller the angular frequency $\theta_i$: the frequency is lower, and the trigonometric function changes more slowly with $m$.

This leads to an intuitive conclusion. Differences between short distances (e.g. positions 1 vs. 5) show up mainly in the high-frequency components (small $i$). Differences between long distances (e.g. 5000 vs. 10000) show up mainly in the low-frequency components (large $i$).

To get extrapolation at short distances and interpolation at long distances, we multiply each frequency by a dimension-dependent factor, $\theta_i \to \lambda_i \theta_i$. At the highest frequency ($i = 0$) we want $\lambda_i = 1$, the same as before the extension. At the lowest frequency ($i = d/2 - 1$) we want $\lambda_i = 1/k$, exactly the reciprocal of the extension factor $k$, which scales back into the pre-extension range.

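An effective choice is an exponential function of $i$, $\lambda_i = \alpha^{-2i/d}$. This is equivalent to scaling the base from $10000$ to $10000\,\alpha$, since $(10000\,\alpha)^{-2i/d} = \lambda_i\,\theta_i$. The boundary condition $\lambda_{d/2-1} = \alpha^{-(d-2)/d} = 1/k$ then gives $\alpha = k^{d/(d-2)}$.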

The key idea of the NTK extension is high-frequency extrapolation plus low-frequency interpolation. It is implemented by directly scaling the base, which is analogous to changing the radix of a number system.

Using NTK to extend to long text, even without fine-tuning, the performance will only decrease slightly.

The following is the implementation of RoPE and three length extension methods.

class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) with a cached cos/sin table.

    The table has shape [1, 1, max_seq_len_cached, dim] and uses the
    "half-split" layout (the frequencies are concatenated twice along the last
    dimension), which pairs with ``rotate_half``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # theta_i = base^(-2i/dim): one angular frequency per pair of channels.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        # persistent=False keeps the buffer out of the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build cos_cached / sin_cached for positions 0 .. seq_len-1."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Outer product: angle[m, i] = m * theta_i.
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin), each [1, 1, seq_len, dim], cast to x.dtype.

        Args:
            x: tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                only its dtype/device (and, by default, its length) are used.
            seq_len: number of positions to return. Defaults to x.shape[-2].
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Beyond the cached length, rebuild a larger RoPE cache; otherwise just slice it.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

    
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """RoPE with linear position interpolation. Credits to the Reddit user /u/kaiokendev.

    Every position index is divided by ``scaling_factor`` before the rotation
    angles are computed.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must exist before the parent constructor runs: it already calls the
        # overridden _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Linear interpolation: shrink all position indices by the same factor.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        angles = torch.outer(positions, self.inv_freq)

        # Half-split layout (frequencies repeated), matching the base class.
        table = torch.cat((angles, angles), dim=-1)[None, None, :, :]
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """RoPE with dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.

    While the cache length stays within ``max_position_embeddings`` this is
    plain RoPE. When a longer cache is built, the base itself is enlarged,
    which leaves the highest frequency unchanged and stretches the lower ones.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must exist before the parent constructor runs: it already calls the
        # overridden _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK extension scales the base directly. `ratio` equals 1 at
            # seq_len == max_position_embeddings and grows linearly with seq_len.
            ratio = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * ratio ** (self.dim / (self.dim - 2))
            channel_pairs = torch.arange(0, self.dim, 2).float().to(device)
            self.register_buffer("inv_freq", 1.0 / (scaled_base ** (channel_pairs / self.dim)), persistent=False)
            # NOTE(review): the enlarged inv_freq is kept afterwards; the base
            # class forward() only rebuilds the cache for longer sequences.

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)

        # Half-split layout: the frequencies are concatenated twice. The paper's
        # interleaved layout would instead place the same frequencies at the even
        # and odd channels; the two differ only by a fixed channel permutation.
        table = torch.cat((angles, angles), dim=-1)[None, None, :, :]
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
def rotate_half(x):
    """Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1).

    Channel j is paired with channel j + d//2 (the half-split layout). The RoPE
    paper instead pairs adjacent channels and produces (-x_1, x_0, -x_3, x_2, ...).
    The two layouts differ only by a fixed permutation of the channels.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate q and k by the RoPE angles of the given positions.

    cos/sin are the [1, 1, seq_len, dim] tables from the rotary embedding.
    Rows are gathered by ``position_ids`` and broadcast over the head axis;
    then R_m x = x * cos + rotate_half(x) * sin is applied to q and k.
    """

    def _rows_for_positions(table):
        # [1, 1, seq_len, dim] -> [seq_len, dim] -> gather -> add a head axis at dim 1
        return table.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)

    cos_pos = _rows_for_positions(cos)
    sin_pos = _rows_for_positions(sin)

    def _rotate(t):
        return (t * cos_pos) + (rotate_half(t) * sin_pos)

    return _rotate(q), _rotate(k)

# Sanity check: the cos table of an 8-dim RoPE for 4 positions.
# forward() only reads x for its dtype/device here, so x's last
# dimension does not have to equal dim.
x = torch.randn(1, 8, 4, 2)
rope = LlamaRotaryEmbedding(dim=8)
cos, sin = rope(x, seq_len=4)
print(cos.shape)
print(cos)
torch.Size([1, 1, 4, 8])
tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000],
          [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,
            1.0000],
          [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,
            1.0000],
          [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,
            1.0000]]]])

2. Multi-head attention LlamaAttention

The LlamaAttention here is basically the same as in the "Attention Is All You Need" paper. The main differences are as follows.

1. The number of k and v heads can be a fraction of the number of q heads (grouped-query attention). This is similar in spirit to group convolution and reduces the number of parameters.

2. RoPE position encoding is applied to q and k inside every multi-head attention layer, instead of being added only once at the input as in the original paper.

3. It is allowed to cache past_key_value of the states of the incoming key and value, which can reduce repeated calculations in multiple rounds of dialogue and have an acceleration effect.

4. The attention_mask is applied to the attention matrix before softmax through addition.

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input must be 4-D even when n_rep == 1.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then merge.
    tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with the Llama changes.

    Compared with the original paper this layer:
      * allows fewer key/value heads than query heads (``num_key_value_heads``);
        each k/v head is shared by ``num_heads // num_key_value_heads`` query
        heads via ``repeat_kv``;
      * applies rotary position embedding (RoPE) to q and k in every layer;
      * can return and consume a key/value cache (``past_key_value``);
      * adds ``attention_mask`` to the attention logits before softmax.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # An uneven split would otherwise only fail later in forward() with a
        # less informative shape error.
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads must be divisible by num_key_value_heads (got `num_heads`: {self.num_heads}"
                f" and `num_key_value_heads`: {self.num_key_value_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Pick the RoPE variant from config.rope_scaling (None, 'linear' or 'dynamic')."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Helper kept for compatibility; forward() below reshapes inline instead.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` of shape (bsz, q_len, hidden_size).

        Args:
            hidden_states: input activations, (bsz, q_len, hidden_size).
            attention_mask: optional additive mask of shape (bsz, 1, q_len, kv_seq_len),
                added to the attention logits before softmax.
            position_ids: positions of the q_len new tokens, used to index the
                RoPE cos/sin tables. If None, defaults to consecutive positions
                continuing after any cached tokens: past_len .. past_len + q_len - 1.
            past_key_value: optional (key, value) cache from an earlier call, each
                (bsz, num_key_value_heads, past_len, head_dim).
            output_attentions: if True, also return the softmax attention weights.
            use_cache: if True, return the updated (key, value) cache.

        Returns:
            (attn_output, attn_weights or None, past_key_value or None), where
            attn_output has shape (bsz, q_len, hidden_size).

        Raises:
            ValueError: if the attention weights, mask or output have an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Compute the projections slice by slice to mimic tensor-parallel pretraining.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        if position_ids is None:
            # Without explicit positions, the new tokens follow the cached ones.
            # Indexing the cos/sin tables with None would instead select all
            # kv_seq_len rows and broadcast q to the wrong length when a cache is used.
            past_length = kv_seq_len - q_len
            position_ids = torch.arange(
                past_length, kv_seq_len, dtype=torch.long, device=hidden_states.device
            ).unsqueeze(0)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # The cache stores the un-repeated k/v heads (already rotated by RoPE).
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: masked positions carry large negative values before softmax.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

3. Feedforward network LlamaMLP

The feedforward network is a 2-layer perceptron MLP.

up_proj first projects from the hidden_size dimension to the intermediate_size dimension, and down_proj then projects back to the hidden_size dimension.

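The distinctive part is gate_proj. Its output goes through the activation function and is multiplied element-wise with the output of up_proj, forming a gated linear unit (with hidden_act='silu', as in the training config later in this article, this is SwiGLU). The gate acts on the MLP's hidden features; it has nothing to do with attention.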

class LlamaMLP(nn.Module):
    """Gated feed-forward network of a Llama decoder layer.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, where ``gate_proj`` and
    ``up_proj`` map ``hidden_size -> intermediate_size`` and ``down_proj`` maps back to
    ``hidden_size``. With ``hidden_act == 'silu'`` this is the SwiGLU variant.

    When ``config.pretraining_tp > 1`` the same result is computed slice by slice along
    the intermediate dimension (mirroring tensor-parallel pretraining).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_size``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        tp = self.config.pretraining_tp
        if tp > 1:
            # Width of each slice along intermediate_size (named to avoid shadowing builtin `slice`).
            slice_size = self.intermediate_size // tp
            gate_proj_slices = self.gate_proj.weight.split(slice_size, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice_size, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice_size, dim=1)

            # Iterate over *all* slices produced by split(): if intermediate_size is not a
            # multiple of tp, split() yields an extra, shorter trailing slice that must not be
            # dropped. When it is a multiple, this is exactly tp slices as before.
            gate_proj = torch.cat([F.linear(x, w) for w in gate_proj_slices], dim=-1)
            up_proj = torch.cat([F.linear(x, w) for w in up_proj_slices], dim=-1)

            # Split on the last (feature) dim instead of a hard-coded dim=2 so inputs that are
            # not exactly 3-D work; the chunking matches down_proj_slices (same slice_size).
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice_size, dim=-1)
            down_proj = sum(
                F.linear(part, w) for part, w in zip(intermediate_states, down_proj_slices)
            )
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

4. Layer normalization LlamaRMSNorm

The layer normalization here is called RMSNorm, which is slightly different from the standard LayerNorm.

First, it divides by the root mean square without subtracting the mean. Second, it adds no bias.

Because of these two changes, the normalization step itself only rescales each hidden_states vector. It changes the vector's length (norm) but not its direction; the learned per-dimension weight is applied afterwards.

This is reasonable in a certain sense.

class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer normalization (equivalent to T5LayerNorm).

    Each vector along the last dimension is divided by its root mean square (no mean
    subtraction, no bias) and then scaled by a learned per-dimension ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Compute the statistics in float32 for numerical stability, whatever the input dtype.
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = (x32 * inv_rms).to(orig_dtype)
        return self.weight * normed
    

5. Llama decoder layer

The decoder layer LlamaDecoderLayer consists of one LlamaAttention, one LlamaMLP and two LlamaRMSNorm modules. It uses a residual connection twice: once around attention and once around the MLP.

class LlamaDecoderLayer(nn.Module):
    """One pre-norm Llama decoder block.

        h   = x + SelfAttention(RMSNorm(x))
        out = h + MLP(RMSNorm(h))
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of shape
                `(batch, 1, tgt_len, src_len)`; blocked positions hold very large negative values.
            position_ids (`torch.LongTensor`, *optional*): positions forwarded to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value states of this layer.
            output_attentions (`bool`, *optional*): also return this layer's attention weights.
            use_cache (`bool`, *optional*): also return the updated key/value cache of this layer.

        Returns:
            `(hidden_states, [attn_weights], [present_key_value])` - the bracketed entries are
            included only when `output_attentions` / `use_cache` are truthy, in that order.
        """
        # Attention sub-block: normalize, attend, add the residual.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-block: normalize, MLP, add the residual.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        layer_output = after_attn + mlp_output

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)
        return (layer_output, *extras)

6. Llama decoder

LlamaModel is stacked by multiple Llama decoding layers.

There are a few key points to understand:

1. `_make_causal_mask` builds a lower-triangular mask so that each position can attend only to itself and earlier positions. This implements the one-way (causal) attention of a language model.

2. `_expand_mask` expands the incoming 2-D padding mask `[bsz, seq_len]` to the same 4-D shape as the attention matrix, `[bsz, 1, tgt_len, src_len]`, so that padding positions are excluded from attention.

3. Setting gradient_checkpointing=True saves GPU memory. It relies on torch.utils.checkpoint.checkpoint, and the idea is simple: during the forward pass of each decoder_layer the intermediate activations are not stored, and during the backward pass they are recomputed. This trades extra compute time for memory.

4. gradient_checkpointing and use_cache cannot both be True; during training, use_cache is switched off automatically with a warning. The former trades time for memory by recomputing activations. The latter trades memory for time by caching past keys and values.

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, 
    device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    """Common base for Llama models: config class, weight initialization and the
    gradient-checkpointing switch used by `PreTrainedModel`."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, initializer_range); zero Linear biases and
        the Embedding padding row. Other module types are left untouched."""
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on a LlamaModel; ignore any other module."""
        if not isinstance(module, LlamaModel):
            return
        module.gradient_checkpointing = value


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Pipeline: token embedding -> `num_hidden_layers` x LlamaDecoderLayer -> final LlamaRMSNorm.
    There is no output head. `forward` returns the final hidden states and, optionally, the
    key/value cache, the hidden states entering each layer (plus the final one), and the
    per-layer attention weights.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The pad token id is passed as padding_idx; _init_weights zeroes that embedding row.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Off by default; switched by LlamaPreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """
        Build the additive 4-D attention mask shared by every decoder layer.

        Args:
            attention_mask: 2-D keep/pad mask `[bsz, past_len + seq_len]` (1/True = keep, 0/False = pad), or None.
            input_shape: `(bsz, seq_len)` of the current, non-cached input.
            inputs_embeds: only its dtype and device are used.
            past_key_values_length: number of cached positions preceding the current input.

        Returns:
            A `[bsz, 1, seq_len, past_len + seq_len]` tensor that is 0 where attention is allowed
            and `torch.finfo(dtype).min` where it is blocked (causal and padding masks are summed),
            or None when neither mask applies.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        # With a single query token (e.g. one cached decoding step) every key is at or before the
        # query, so the causal mask is only built when more than one token is processed.
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Per-call flags fall back to the config defaults when not given.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            # Cached length = sequence dim (2) of the first layer's cached key tensor.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            # Default positions continue right after the cached prefix:
            # past_len, past_len + 1, ..., past_len + seq_length - 1, shaped (1, seq_length).
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Build the attention mask. Without a user-supplied mask every position (cached and
        # current) is treated as a real token, i.e. no padding.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # Checkpointing recomputes layers during backward and the checkpointed call below passes no
        # cache, so caching is forced off while training with checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Records the hidden state *entering* each layer.
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                # The non-tensor flags are bound in this closure; checkpoint() receives
                # (hidden_states, attention_mask, position_ids, past_key_value=None) positionally.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        # (the trailing None is passed as use_cache)
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            # Layer output is (hidden, [attn_weights], [present_kv]); the cache entry sits after
            # the attention weights when those are requested.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        # Tuple form drops the entries that are None.
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

7. Llama language model

The Llama language model LlamaForCausalLM adds an lm_head on top of the Llama decoder LlamaModel. The lm_head is a linear layer that maps hidden states to vocabulary logits.

This gives a complete language model.

In addition, the Llama language model also implements the following important functions.

1. Loss calculation function. When the labels are passed in the forward method, the cross-entropy loss of the language model is automatically calculated. Note that -100 in labels will be ignored and not involved in the calculation.

2. Text generation with the generate method, which LlamaForCausalLM inherits through PreTrainedModel. Set model.generation_config.num_beams to choose the beam width for beam search; the default of 1 means greedy search.

# Config class name that replace_return_docstrings inserts into LlamaForCausalLM.forward's docstring.
_CONFIG_FOR_DOC = "LlamaConfig"

class LlamaForCausalLM(LlamaPreTrainedModel):
    """Llama decoder (`LlamaModel`) with a linear language-modeling head.

    `forward` returns next-token logits over the vocabulary and, when `labels` are given,
    the shifted cross-entropy loss. Text generation (`generate`) is inherited;
    `prepare_inputs_for_generation` and `_reorder_cache` below adapt the inputs and the
    key/value cache for it.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        # Maps hidden states (hidden_size) to vocabulary logits (vocab_size); no bias.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # replace_return_docstrings looks for an empty `Returns:` line in forward's docstring and
    # fills it in; without that placeholder the decorator fails when the class is created.
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Target token ids for the language-modeling loss, aligned with `input_ids`: the
                logits at position `t` are scored against `labels[:, t + 1]`. Positions labelled
                `-100` are ignored by the loss.

        Returns:
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Same result as lm_head, computed in vocab slices (tensor-parallel layout).
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        # Logits are always returned (and the loss computed) in float32.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            # CrossEntropyLoss's default ignore_index (-100) skips masked positions.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one `forward` call during generation."""
        # With a cache, only the newest token is fed; earlier tokens live in past_key_values.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # Positions count only unmasked tokens; masked (padding) slots get a placeholder of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor of every layer along the batch dim (0) by `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

8. Llama classification model

LlamaForSequenceClassification is a sequence classification model.

This classification model can be used to train the Reward model in the RLHF process.

@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    # LlamaModel plus a bias-free linear `score` head producing num_labels logits per position;
    # forward() keeps only each row's last-token logits (see the decorator docstring above).
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-position logits; indexed below as (batch, position, num_labels).
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token the last non-padding position cannot be located per row.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the first pad token minus one = last real token for right-padded rows.
                # argmax returns the first maximum, so a row without padding yields 0 - 1 = -1,
                # i.e. the final position. If pad_token_id also occurs inside the text, its first
                # occurrence is what counts.
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        # One logit vector per row, taken at each row's chosen position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer problem_type once from num_labels and the labels' dtype; note this writes to self.config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

III. Train the model

Next, let's train a LlamaForCausalLM to implement the task of summing two numbers.

# Small Llama configuration for the two-number addition task.
config = LlamaConfig(
    vocab_size=len(vocab),
    hidden_size=512,
    intermediate_size=2752,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,      # fewer K/V heads than query heads: repeat_kv shares each K/V head
    rope_scaling = None,
    hidden_act='silu',
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,             # matches the index of '<PAD>' in vocab
    bos_token_id=1,             # '<BOS>'
    eos_token_id=2,             # '<EOS>'
    tie_word_embeddings=False,
    pretraining_tp = 1,
    max_new_tokens = 100
) 

# Trial run: one forward pass to check that shapes line up and a loss is produced.
# Assumes `batch` is a batch taken from the training DataLoader built earlier in the article.
model = LlamaForCausalLM(config)
# Call the module itself rather than .forward() so that any registered nn.Module hooks run.
out = model(**batch)
print(out.loss)

tensor(2.7630, grad_fn=<NllLossBackward0>)

from torchkeras import KerasModel 
from accelerate import Accelerator 

class StepRunner:
    """One train/eval step for torchkeras' KerasModel on a model that computes its own loss.

    The model is called as ``self.net(**batch)`` and its ``.loss`` is used, so ``loss_fn`` is
    accepted only for interface compatibility. In the "train" stage (with an optimizer) the step
    does backward, gradient clipping to norm 1.0, optimizer/scheduler step and zero_grad.

    Returns:
        ``(step_losses, step_metrics)`` where ``step_losses`` is ``{stage + "_loss": float}``
        (loss averaged over processes) and ``step_metrics`` holds the current ``lr`` in the
        "train" stage and is empty otherwise.
    """

    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        if self.stage == 'train':
            self.net.train()
        else:
            self.net.eval()

    def __call__(self, batch):
        # Forward pass: the model computes the loss itself from the labels in `batch`.
        with self.accelerator.autocast():
            loss = self.net(**batch).loss

        # Backward and parameter update (training only).
        if self.stage == "train" and self.optimizer is not None:
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        # Average (not sum) the per-process losses so the logged loss does not grow with the
        # number of processes; detach because the value is only used for logging.
        all_loss = self.accelerator.gather(loss.detach()).mean()

        # losses (or plain metrics that can be averaged)
        step_losses = {self.stage + "_loss": all_loss.item()}

        # metrics (stateful metrics)
        step_metrics = {}
        if self.stage == "train":
            # Read lr straight from param_groups instead of building optimizer.state_dict() every step.
            step_metrics['lr'] = self.optimizer.param_groups[0]['lr'] if self.optimizer is not None else 0.0
        return step_losses, step_metrics
    
# Make KerasModel use the custom StepRunner for every train/val step.
KerasModel.StepRunner = StepRunner

keras_model = KerasModel(
    model,
    loss_fn=None,   # unused: the model computes its own loss
    optimizer=torch.optim.AdamW(model.parameters(), lr=3e-5),
)

# Checkpoint location handed to fit(); presumably where torchkeras keeps the best weights
# according to `monitor`/`mode` (lowest val_loss). Nothing here loads earlier weights explicitly.
ckpt_path = 'llama_twosum'

keras_model.fit(
    train_data=dl_train,
    val_data=dl_val,
    epochs=100,
    patience=5,
    monitor='val_loss',
    mode='min',
    ckpt_path=ckpt_path,
    mixed_precision='fp16',
)

[Figure omitted]

IV. Use the model

from transformers.generation.utils import GenerationConfig

# A GenerationConfig built with from_dict carries only the keys given here, so the special-token
# ids must be included explicitly. Without eos_token_id, generate() does not stop at <EOS> and
# keeps emitting tokens up to the length limit.
# max_length is not set: max_new_tokens takes precedence over it when both are given.
model.generation_config = GenerationConfig.from_dict({
    'num_beams': 1,                                 # 1 = greedy search
    'max_new_tokens': 100,
    'bos_token_id': model.config.bos_token_id,
    'eos_token_id': model.config.eos_token_id,
    'pad_token_id': model.config.pad_token_id,
})
def get_ans(tensor, id2token=None) -> str:
    """Decode a generated id sequence and return the answer text.

    The answer is the text after the first '=' and before the first '<EOS>' that follows it;
    if no '<EOS>' was generated, everything after '=' is returned. '<BOS>' markers are removed.

    Args:
        tensor: 1-D sequence of token ids (anything with ``.tolist()``, e.g. a torch tensor).
        id2token: id -> token lookup; defaults to the module-level ``vocab_r``.
    """
    if id2token is None:
        id2token = vocab_r
    text = "".join(id2token[i] for i in tensor.tolist())
    start = text.find('=') + 1          # 0 if there is no '=' (decode from the beginning)
    end = text.find('<EOS>', start)
    if end == -1:                       # no EOS: keep the text up to the end instead of dropping a char
        end = len(text)
    return text[start:end].replace('<BOS>', '')
# Draw one random addition problem and show the prompt and the expected answer.
x, y = get_data()
prompt_text = ''.join(x).replace('<BOS>', '')
answer_text = ''.join(y).replace('<EOS>', '')
print(f'x: {prompt_text}')
print(f'y: {answer_text}')
x: 3481340050+90157504501803=
y: 90160985841853
# Encode the prompt as a (1, seq_len) batch of token ids and let the model continue it.
input_ids = torch.tensor([[vocab[tok] for tok in x]])
out = model.generate(inputs=input_ids)
out

tensor([[ 1,  5,  6, 10,  3,  5,  6, 12, 12,  7, 12, 13, 11, 12,  3,  7,  9,  7, 12,  6,  7, 12,  3, 10, 12,  5, 14, 11, 12,  3,  8, 12, 11, 10,  7, 10, 6,  3, 10,  7,  5,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 12,  2,  2,  2,  2,  2,  2,  2, 2, 12,  3, 12,  3]])

# Decode the answer part of the single generated sequence.
get_ans(tensor=out[0])

'90160985841853'

V. Evaluate the model

from tqdm import tqdm

# Accuracy on 200 freshly sampled problems (exact string match of the decoded answer).
loop = tqdm(range(1, 201))
correct = 0
for step in loop:
    x, y = get_data()
    input_ids = torch.tensor([[vocab[tok] for tok in x]])
    out = model.generate(inputs=input_ids)
    pred = get_ans(out[0])
    gt = ''.join(y).replace('<EOS>', '')
    correct += int(pred == gt)
    loop.set_postfix(acc=correct / step)

print("acc=", correct / len(loop))

acc= 0.99

Nice, our test is 99% accurate!

Guess you like

Origin blog.csdn.net/sinat_37574187/article/details/132197855