라마는 간단한 용어로 설명합니다

라마는 간단한 용어로 설명합니다

다음 기사는 Algorithm Gourmet House, 저자 Liang Yun 1991에서 발췌한 것입니다.

알고리즘 고메 하우스.

옛날 옛적에~ 끈질긴 미식가가 있었다~ 복잡한 알고리즘을 만들어서~ 미식으로~

사전 경고: 이것은 오픈 소스 LLM 모델의 소스 코드를 배우기 위해 찾을 있는 가장 쉽고 실용적인 튜토리얼 일 수 있습니다.

이 예제는 처음부터 Transformers 라이브러리를 기반으로 Llama 모델 모듈의 소스 코드를 모듈별로 빌드하고 해석합니다 (중국어는 알파카로 번역될 수 있음).

그리고 두 숫자의 합이라는 흥미로운 예를 구현하도록 훈련시킵니다.

입력 및 출력은 다음과 유사합니다.

입력: "12345+54321="

출력: "66666"

이 작업을 텍스트 생성 작업으로 취급합니다. 입력은 시퀀스의 위쪽 절반이고 아래쪽 절반은 출력입니다.

이는 텍스트 생성의 입력 및 출력 구조와 유사하므로 Llama로 수행할 수 있습니다.

현재 대부분의 오픈 소스 LLM 모델은 변환기 라이브러리를 기반으로 하며 대부분의 구조는 Llama와 유사합니다.

속담처럼 악마는 세부 사항에 숨어 있습니다.Llama 모델의 소스 코드 세부 사항 에 대한 깊은 이해는 오픈 소스 LLM 모델과 관련된 기본 원칙(예: 회전 위치 인코딩 및 길이 외삽)을 이해하는 데 도움이 되며, 다양한 매개변수에 익숙해지도록 구성 및 사용(예: past_key_value , Attention_mask 사용 등).

1. 데이터 준비

import random

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader

# 定义字典
words = '<PAD>,<BOS>,<EOS>,1,2,3,4,5,6,7,8,9,0,+,='
vocab = {word: i for i, word in enumerate(words.split(','))}
vocab_r = [k for k, v in vocab.items()] #反查词典

#两数相加数据集
def get_data(min_length=10,max_length=20):
    # 定义词集合
    words = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # 每个词被选中的概率
    p = np.array([7, 5, 5, 7, 6, 5, 7, 6, 5, 7])
    p = p / p.sum()

    # 随机采样n1个词作为s1
    n1 = random.randint(min_length, max_length)
    s1 = np.random.choice(words, size=n1, replace=True, p=p)
    s1 = s1.tolist()

    # 随机采样n2个词作为s2
    n2 = random.randint(min_length, max_length)
    s2 = np.random.choice(words, size=n2, replace=True, p=p)
    s2 = s2.tolist()

    # x等于s1和s2字符上的相加
    x = s1 + ['+'] + s2 + ['=']
    
    # y等于s1和s2数值上的相加
    y = int(''.join(s1)) + int(''.join(s2))
    y = list(str(y))
    
    # 加上首尾符号
    x = ['<BOS>'] + x 
    y =  y + ['<EOS>']
    
    return x,y

x,y = get_data() 
print(''.join(x)+''.join(y),"\n")


<BOS>3914835626735057733+318829464988=3914835945564522721<EOS>
# 定义数据集
class TwoSumDataset(torch.utils.data.Dataset):
    def __init__(self,size = 100000, min_length=10,max_length=20):
        super(Dataset, self).__init__()
        self.size = size
        self.min_length=min_length
        self.max_length=max_length

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        x,y = self.get(i)
        
        # 编码成token
        context_ids = [vocab[i] for i in x]
        target_ids = [vocab[i] for i in y]
        
        input_ids = context_ids + target_ids
        
        #-100标志位后面会在计算loss时会被忽略不贡献损失,我们集中优化target部分生成的loss
        labels = [-100]*len(context_ids)+ target_ids
        masks = [0 if t==vocab['<PAD>'] else 1 for t in input_ids]
        
        example = {'input_ids':input_ids,
                  'labels':labels,'attention_mask':masks}
        
        return example
    
    def get(self,i):
        return get_data(self.min_length,self.max_length)
    
    
    def show_example(self,example):
        input_ids,labels = example['input_ids'],example['labels']
        x = ''.join([vocab_r[a] for a,b in zip(input_ids,labels) if b==-100])
        y = ''.join([vocab_r[a] for a,b in zip(input_ids,labels) if b!=-100])
        print(x+y)
        
        
    
ds_train = TwoSumDataset(size = 100000,min_length=10,max_length=20)
ds_val = TwoSumDataset(size = 10000,min_length=10,max_length=20)
example = ds_train[0]
ds_train.show_example(example)

<BOS>12878683929048906366+11274414130675477=12889958343179581843<EOS>
def data_collator(examples: list):
    len_ids = [len(example["input_ids"]) for example in examples]
    longest = max(len_ids) #之后按照batch中最长的input_ids进行padding
    
    input_ids = []
    labels_list = []
    masks_list = []
    
    for length, example in sorted(zip(len_ids, examples), key=lambda x: -x[0]):
        ids = example["input_ids"]
        labs = example["labels"]
        masks = example['attention_mask']
        
        ids = [vocab['<PAD>']] * (longest - length)+ids 
        labs = [-100] * (longest - length)+labs
        masks = [0]*(longest - length)+masks
        
        input_ids.append(torch.LongTensor(ids))
        labels_list.append(torch.LongTensor(labs))
        masks_list.append(torch.LongTensor(masks))
          
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    attention_mask = torch.stack(masks_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask":attention_mask
    }

# 数据加载器
dl_train = DataLoader(dataset=ds_train,
         batch_size=200,
         drop_last=True,
         shuffle=True,
         collate_fn = data_collator        
        )

dl_val = DataLoader(dataset=ds_val,
         batch_size=200,
         drop_last=True,
         shuffle=False,
         collate_fn = data_collator  
        )


for batch in dl_train:
    break 
batch 
{'input_ids': tensor([[ 1, 11,  6,  ...,  7, 11,  2],
         [ 0,  1,  6,  ...,  5,  4,  2],
         [ 0,  1,  7,  ...,  8,  8,  2],
         ...,
         [ 0,  0,  0,  ..., 10, 11,  2],
         [ 0,  0,  0,  ..., 12,  3,  2],
         [ 0,  0,  0,  ..., 11, 12,  2]]),
 'labels': tensor([[-100, -100, -100,  ...,    7,   11,    2],
         [-100, -100, -100,  ...,    5,    4,    2],
         [-100, -100, -100,  ...,    8,    8,    2],
         ...,
         [-100, -100, -100,  ...,   10,   11,    2],
         [-100, -100, -100,  ...,   12,    3,    2],
         [-100, -100, -100,  ...,   11,   12,    2]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]])}

둘째, 모델 정의

다음으로, 성을 쌓기 위해 빌딩 블록처럼 LLaMA 모델을 낮은 곳에서 높은 곳으로 만들 것입니다.

먼저 4가지 기본 구성 요소를 빌드합니다. 회전 위치 인코딩, 다중 헤드 어텐션, 피드포워드 네트워크, 계층 정규화입니다. 벽, 지붕, 문, 창과 같은 모듈은 가장 기본적인 빌딩 블록으로 구성됩니다.

그런 다음 이 4가지 기본 구성 요소를 사용하여 중간 제품인 디코딩 계층을 구축합니다. 기본 구성 요소로 방을 만드는 것과 같습니다.

그런 다음 LlamaModel의 전체 모델은 여러 개의 중간 디코딩 레이어를 쌓아서 조립합니다. 이는 여러 개의 방을 만들어 성의 주요 구조를 만드는 것과 같습니다.

마지막으로 LlamaModel을 기반으로 두 개의 서로 다른 출력 헤드를 설계했습니다. 하나는 언어 모델인 Head이고 텍스트 생성에 사용할 수 있는 LlamaForCausalLM을 얻었습니다.

다른 하나는 텍스트 분류에 사용할 수 있는 LlamaForSequenceClassification을 획득한 분류 헤드입니다.

성의 주요 구조가 완성된 것을 기준으로 두 가지 다른 장식 양식을 설계하는 것과 같으며, 하나는 상업 활동을 위한 일부 레크리에이션 시설을 설치하는 것이고, 다른 하나는 군사 활동을 위한 일부 무기를 설치하는 것입니다. .


1, 회전 위치 인코딩: RoPE( 회전 매트릭스를 사용하여 실현된 절대 위치 인코딩 , 상대 위치 인코딩의 효과를 얻을 수 있음)

2, 멀티 헤드 어텐션: LlamaAttention ( 서로 다른 토큰 간의 정보 융합 에 사용됨 )

3, 피드포워드 네트워크: LlamaMLP( 위치별 다중 헤드 어텐션 융합 후 정보의 고차원 매핑 변환 용 )

4, 레이어 정규화: LlamaRMSNorm( 입력을 안정화하는 데 사용되며, 이는 각 단어 벡터의 방향을 변경하지 않고 모듈러스 길이를 표준화하는 것과 같습니다. )


5, 라마 디코딩 레이어: LlamaDecoderLayer ( 정보 융합 및 정보 변환 기능을 동시에 가진 기본 구조 단위 )


6, 라마 디코더: LlamaModel( 여러 디코딩 레이어의 스태킹 )


7. Llama 언어 모델: LlamaForCausalLM(텍스트 생성에 사용할 수 있는 디코더와 언어 모델 헤드)

8. 라마 분류 모델: LlamaForSequenceClassification(디코더와 분류 헤드, 텍스트 분류에 사용할 수 있음)


import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

from transformers.models.llama.configuration_llama  import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING

logger = logging.get_logger('llama')

config = LlamaConfig(
    vocab_size=len(vocab),
    hidden_size=512,
    intermediate_size=2752,
    num_hidden_layers=8,
    num_attention_heads=16,
    hidden_act='silu',
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    tie_word_embeddings=False
) 

1. 회전 위치 인코딩 RoPE

회전 위치 인코딩은 회전 행렬을 사용하여 RoPE라고 하는 위치 인코딩(Rotary Position Encoding)을 나타냅니다.

RoPE에 대한 세 가지 핵심 지식은 다음과 같습니다.

  • RoPE의 설계 아이디어는 절대 위치 인코딩을 사용하여 상대 위치 인코딩의 효과를 얻는 것입니다.

  • RoPE는 절대 위치 인코딩을 나타내는 회전 행렬을 사용하여 구현됩니다.

  • NTK 확장 방법을 사용하면 RoPE 가 짧은 텍스트를 학습 하고 긴 텍스트를 예측할 수 있습니다.

참조 문서:

" 남에게 배우는 로타리 포지션 코딩 " https://kexue.fm/archives/8265

"RoPE는 16진수 인코딩입니다" https://kexue.fm/archives/9675

(1) 절대위치 코딩과 상대위치 코딩

위치 코딩은 일반적으로 절대 위치 코딩과 상대 위치 코딩으로 나눌 수 있습니다.

절대 위치 부호화의 장점은 계산이 간단하고 효율적이라는 점이지만 단점은 상대 위치 부호화에 비해 전반적인 효과가 좋지 않다는 점이다.

상대 위치 부호화의 장점은 효과가 더 좋다는 것이지만, 절대 위치 부호화에 비해 계산 효율이 좋지 않다는 단점이 있다.

절대 위치 인코딩:

상대 위치 인코딩:

상대적 위치 부호화에서 어텐션 가중치의 결과는 어텐션 계산에 포함된 토큰 벡터의 상대적 위치와만 관련되며 절대 위치와는 직접적인 관련이 없습니다.

이는 NLP 필드가 시퀀스 길이 방향으로 변환 불변성을 가지므로 일반적으로 상대 위치 인코딩이 절대 위치 인코딩보다 낫다는 사실과 일치합니다.

절대위치 부호화는 초기화 과정에서 시퀀스의 각 위치에 위치코드(숫자는 시퀀스 길이에 비례)를 할당하기만 하면 되고, 이후의 개입은 필요하지 않습니다.

그러나 상대 위치 부호화는 계산 과정에서 많은 상대 위치(수는 시퀀스 길이의 제곱에 비례함)를 얻어야 합니다.

따라서 절대 위치 코딩이 더 간단하고 효율적입니다.

(2) 회전 행렬을 사용하여 위치 인코딩을 나타냅니다.

위의 논의에서 알 수 있듯이 절대 위치 부호화와 상대 위치 부호화에는 장단점이 있는데 서로 배울 수 있는 방법은 없을까요?

예, 이 방법은 RoPE이며 설계 아이디어는 상대 위치 인코딩의 효과를 달성하기 위해 절대 위치 인코딩을 사용하는 것입니다.

그렇다면 회전 위치 인코딩은 상대 위치 인코딩의 효과를 얻기 위해 절대 위치 인코딩을 어떻게 사용합니까? 대답은 위치 인코딩을 나타내기 위해 회전 행렬을 사용하는 것입니다.

그중에는 속성을 만족하는 회전 행렬이 있습니다. 따라서 다음이 있습니다.

 상대 위치 인코딩 형식을 따릅니다.

우리는 절대 위치 인코딩을 사용하여 상대 위치 인코딩 효과를 얻습니다.

그렇다면 회전 행렬은 어떻게 생겼을까요?

2차원에서는 이렇게 보입니다.

NLP 분야에서 단어 벡터의 차원은 일반적으로 매우 높습니다(예: 4096).

행렬의 블록 아이디어를 사용하여 고차원의 경우 다음과 같은 형태로의 확장이 여전히 회전 행렬의 속성을 만족함을 증명할 수 있습니다.

그 중 즉 차원이 높을수록 삼각함수의 계수는 작아지고 주기는 커지며 변화는 느려진다.

회전 행렬은 희소 행렬이기 때문에 곱셈 계산을 직접 사용하는 것은 컴퓨팅 파워의 낭비가 될 것입니다 회전 위치 부호화 과정은 행렬 곱셈 연산에서 두 벡터의 Hadamard 곱의 합으로 단순화될 수 있습니다.

(3) 회전 위치 인코딩의 길이 확장

LLM을 적용함에 있어서 LLM이 지원하는 context length(max context length)라는 매우 중요한 파라미터가 있다.

맥락 길이가 길수록 더 많은 대화를 할 수 있고, 더 긴 논문의 요약 분석을 수행할 수 있으며, 또한 더 긴 기사를 생성할 수 있습니다.

그러나 LLM을 교육할 때 대부분의 교육 코퍼스는 충분히 길지 않습니다.많은 LLM 교육을 위해 설계된 최대 텍스트 길이는 2k로 가장 긴 2048 토큰입니다.

그렇다면 학습 중에 짧은 텍스트를 사용하고 추론 중에 긴 텍스트로 확장할 수 있습니까?

RoPE의 길이를 연장할 수 있습니다.

3가지 확장 시나리오를 소개합니다.

첫 번째는 직접 외삽입니다. 직접 외삽은 실제로 수정 없이 기존 위치 인코딩 공식을 계속 사용합니다.

2k에서 2.5k로 확장하는 것과 같이 확장 길이가 너무 길지 않은 경우 이 방법은 성능에 거의 영향을 미치지 않을 수 있습니다.

회전 위치 코드는 상대 위치 mn의 크기 에만 관련되기 때문에 일반적으로 장거리 감쇠가 있습니다. 즉, 상대적 거리가 더 큰 두 토큰 간의 상관 관계는 일반적으로 약합니다.

따라서 우리 모델이 훈련 데이터로부터 0-2k 사이의 상대적인 거리에 대한 토큰 간의 상관 관계에 대한 적절한 감쇠 법칙을 학습했다면, 이 법칙을 0-2.5k에 적용하는 것은 그리 큰 문제가 아니라고 생각할 수 있습니다.

그러나 2k에서 32k로 확장하는 것과 같이 더 긴 길이로 확장하려는 경우 이러한 직접 외삽 방식은 일반적으로 성능에 심각한 영향을 미칩니다. 우리가 배운 감쇠 법칙은 완전히 감쇠되어 5k에서 0으로 잘릴 수 있으므로 상대 거리가 5k보다 긴 두 토큰 간의 상호 작용을 캡처할 수 없으며 외삽은 성능 저하로 이어질 것입니다.

요약하자면 장거리 감쇠 법칙에 직접 외삽법을 사용하면 성능 저하로 이어지는 문제가 발생하기 쉽습니다.

길이 외삽이 성능에 미치는 영향을 줄이기 위해 더 긴 컨텍스트에서 몇 단계로 훈련된 모델을 미세 조정할 수 있습니다.

두 번째는 선형 보간입니다. 선형 보간은 위치 인코딩 공식을 변경해야 하며 이는 위치 일련 번호를 비례적으로 줄이는 것과 같습니다.

인코딩 공식은 와 같이 변경되는데, 2k에서 32k로 확장할 때 위치 번호를 원본의 1/16으로 변경하는 것과 같습니다.

선형 보간은 모델이 학습한 감쇠 법칙의 적용 범위를 변경하지 않으며, 그 효과는 일반적으로 미세 조정을 고려하지 않은 직접 외삽 방식보다 낫습니다.

그러나 2k에서 32k로 확장하는 것과 같이 확장 배율이 매우 크면 성능에 상당한 영향을 미칩니다.

이 경우 짧은 거리의 경우 감쇠법의 사용에 심각한 영향을 미치게 되므로 거리가 1인 두 개의 토큰은 길이를 늘린 후 1/16의 거리에 해당하며 감쇠법은 거리가 매우 큰 변화율을 가질 수 있으므로 상관 관계 평가가 합리적인 값에서 크게 벗어날 수 있습니다.

긴 텍스트에 대해 몇 단계를 미세 조정하면 선형 보간이 적용될 때 성능이 크게 향상될 수도 있습니다.

세 번째는 NTK 확장 방식으로 외삽법과 보간법의 장점을 결합한 방식으로 길이 확장 후 미세 조정 없이도 좋은 성능을 유지할 수 있다.

이전 분석에서 우리는 장거리 상황에서 감쇠 법칙에 직접 외삽을 사용하는 것이 문제가 발생하기 쉽고 근거리 상황에서 사용하는 데 영향을 미치지 않는다는 것을 알고 있습니다.

그러나 선형보간법은 단거리의 경우 감쇠법칙을 사용함에 있어 문제가 발생하기 쉽고, 장거리의 경우에는 거의 영향을 미치지 못한다.

근거리 경우의 외삽 속성(기본적으로 확장 전과 동일)과 장거리 경우의 보간 속성(확장 전 범위로 조정됨)을 결합하여 장거리 경우와 사용 짧은 거리에서의 감쇠 법칙의 영향을 많이 받지 않습니다.

RoPE 위치 코드의 첫 번째 줄에서 요소 계산 공식을 관찰하고 값이 클수록 삼각 함수에 해당하는 각 주파수 계수가 작거나 주파수가 낮을수록 해당 삼각 함수가 느려지는 것을 알 수 있습니다. 변화.

다음과 같은 직관적인 결론을 쉽게 얻을 수 있습니다: 짧은 거리의 차이(예: 1과 5의 차이)는 주로 고주파 성분(i가 상대적으로 작음)에 반영되고, 장거리의 차이(예: 차이) 5000~10000 사이), 주로 저주파 성분에 반영됩니다(i가 상대적으로 큼).

단거리의 경우 보간 특성을, 장거리의 경우 보간 특성을 가지기 위해 가장 높은 주파수()에서 값이 1이 되도록 위치 번호와 관련된 스케일링 계수를 설계할 수 있습니다(기존과 동일). 확장은 기본적으로 동일), 가장 낮은 주파수에서()는 줌 계수의 역수입니다(확장 전 범위로 확대됨).

유효한 옵션은 의 지수함수이며 그 효과는 의 스케일링과 동일하며 경계 조건에 따라 적절한 스케일링 계수를 쉽게 얻을 수 있습니다.

NTK 확장 방법의 요점은 고주파 외삽과 저주파 보간이며 구현 방법은 기본 인코딩 변환 과 유사하게 기본 숫자를 직접 스케일링하는 것입니다 .

NTK를 사용하여 긴 텍스트로 확장하면 미세 조정 없이도 성능이 약간만 저하됩니다.

다음은 RoPE 및 세 가지 길이 확장 방법의 구현입니다.

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False) #persistent=False将不会作为state_dict

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        #超过预设的max_position_embeddings则重新计算更大的Rope缓存,否则直接在缓存上切片
        if seq_len > self.max_seq_len_cached: 
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

    
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor #线性内插相当于将位置序号等比例缩小

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))  #NTK扩展方式直接对base进行缩放
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        
        #此处处理逻辑与原始的ROPE有差异,原始逻辑如下
        #emb = torch.cat((freqs, freqs), dim=-1)
        #emb[...,0::2]=freqs
        #emb[...,1::2]=freqs
        
        
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
        
        
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    
    #此处逻辑与原始的ROPE有所差异,原始逻辑如下
    #x1 = x[..., 0::2] 
    #x2 = x[..., 1::2]
    #res = torch.cat((x1, x2), dim=-1)
    #res[...,0::2]=-x2
    #res[...,1::2]=x1
    #return res
    
    x1 = x[..., : x.shape[-1] // 2] 
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

x = torch.randn(1,8,4,2)
rope = LlamaRotaryEmbedding(dim=8)
cos,sin = rope.forward(x,seq_len=4)
print(cos.shape) 
print(cos)
torch.Size([1, 1, 4, 8])
tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000],
          [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,
            1.0000],
          [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,
            1.0000],
          [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,
            1.0000]]]])

2. 멀티 헤드 어텐션 LlamaAttention

여기서 LlamaAttention은 기본적으로 "Attention Is All You Need" 논문과 동일합니다. 주요 차이점은 다음과 같습니다.

1. k와 v의 헤드 수는 q의 헤드 수의 분수일 수 있습니다. 그룹 컨벌루션의 아이디어와 유사하게 매개변수 스케일을 줄일 수 있습니다.

2. 원지 입력 시 한 번이 아닌 멀티 헤드 어텐션이 이루어질 때마다 로프 위치 인코딩을 수행한다.

3. 들어오는 키 및 값의 상태에 대한 past_key_value를 캐시할 수 있으므로 여러 라운드의 대화에서 반복 계산을 줄이고 가속 효과를 얻을 수 있습니다.

4. 어텐션_마스크는 추가를 통해 softmax 이전 어텐션 매트릭스에 적용됩니다.

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    
    

3. 피드포워드 네트워크 LlamaMLP

피드포워드 신경망은 2계층 퍼셉트론 MLP입니다.

먼저 hidden_size 차원 up_proj에서 medium_size 차원으로 이동한 다음 down_proj가 hidden_size 차원으로 복원합니다.

여기서 주요 기능은 gated Attention 기능을 달성하기 위해 활성화 기능이 있는 gate_proj를 도입한 것입니다.

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

4. 계층 정규화 LlamaRMSNorm

여기에서 계층 정규화를 RMSNorm이라고 하며 표준 LayerNorm과 약간 다릅니다.

첫 번째는 평균을 제거하지 않고 직접 나눈 다음 바이어스를 추가하지 않은 RootMeanSquare입니다.

이 두 가지 작은 수정은 레이어의 정규화가 hidden_states에 해당하는 단어 벡터의 방향을 변경하지 않고 계수 길이만 변경하도록 보장할 수 있습니다.

어떤 의미에서 합리적이다.

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
    

5, 라마 디코딩 레이어

디코딩 레이어 LlamaDecoderLayer는 LlamaAttention, LlamaMLP, 2개의 LlamaRMSNorm으로 구성되어 있으며, 잔여 구조를 2번 사용합니다.

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

6. 라마 디코더

LlamaModel은 여러 Llama 디코딩 레이어로 쌓입니다.

이해해야 할 몇 가지 핵심 사항이 있습니다.

1. _make_causal_mask언어 모델의 단방향 주의를 실현하기 위해 하부 삼각형의 마스크 구조를 구성하는 데 사용됩니다.

2. _expand_mask특수 기호와 관련된 들어오는 마스크 정보를 어텐션 매트릭스와 동일한 텐서 구조로 확장하는 데 사용됩니다.

3. gradient_checkpointing=True로 설정하면 비디오 메모리를 절약할 수 있습니다. 주로 torch.utils.checkpoint.checkpoint 메소드를 사용합니다. 그 원리는 매우 간단하여 decoder_layer가 전달되면 비디오 메모리를 절약하기 위해 중간 활성화 값이 저장되지 않고, 역방향이 다시 계산되면 해당 값이 다시 계산되어 시간과 공간을 교환합니다.

4. Gradient_checkpointing과 use_cache는 동시에 True로 설정할 수 없으며, 전자는 공간에 대한 비디오 메모리 시간을 절약하고 후자는 시간에 대한 시간과 공간을 절약합니다.

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, 
    device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

7. 라마 언어 모델

Llama 언어 모델 LlamaForCausalLM은 lm_head를 생성기로 추가하여 Llama 디코더 LlamaModel을 기반으로 합니다.

따라서 완전한 언어 모델이 실현됩니다.

또한 Llama 언어 모델은 다음과 같은 중요한 기능도 구현합니다.

1. 손실 계산 기능. 레이블이 전달 방법으로 전달되면 언어 모델의 교차 엔트로피 손실이 자동으로 계산됩니다. 레이블의 -100은 무시되며 계산에 포함되지 않습니다.

2. 텍스트 생성 생성 방법. 이 방법은 PreTrainedModel에서 상속받은 것으로 model.generation_config.num_beams를 설정하여 빔 탐색의 빔 폭을 선택할 수 있습니다.기본값은 1이며 욕심 많은 탐색입니다.

_CONFIG_FOR_DOC = "LlamaConfig"

class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

8. 라마 분류 모델

LlamaForSequenceClassification은 시퀀스 분류 모델입니다.

이 분류 모델은 RLHF 프로세스에서 보상 모델을 훈련하는 데 사용할 수 있습니다.

@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

셋, 교육 모델

다음으로 LlamaForCausalLM을 훈련시켜 두 숫자를 합산하는 작업을 구현해 보겠습니다.

config = LlamaConfig(
    vocab_size=len(vocab),
    hidden_size=512,
    intermediate_size=2752,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,
    rope_scaling = None,
    hidden_act='silu',
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    tie_word_embeddings=False,
    pretraining_tp = 1,
    max_new_tokens = 100
) 

#试算一下
model = LlamaForCausalLM(config)
out = model.forward(**batch)
print(out.loss)

텐서(2.7630, grad_fn=)

from torchkeras import KerasModel 
from accelerate import Accelerator 

class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator() 
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        #loss
        with self.accelerator.autocast():
            loss = self.net(**batch).loss

        #backward()
        if self.stage=="train" and self.optimizer is not None:        
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses (or plain metrics that can be averaged)
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics (stateful metrics)
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner 


keras_model = KerasModel(model,loss_fn = None,
        optimizer=torch.optim.AdamW(model.parameters(),lr=3e-5))


#加载 之前训练过的权重
ckpt_path = 'llama_twosum'

keras_model.fit(train_data = dl_train,
                val_data = dl_val,
                epochs=100,patience=5,
                monitor='val_loss',mode='min',
                ckpt_path = ckpt_path,
                mixed_precision='fp16'
               )

그림

넷째, 모델 사용

from transformers.generation.utils import GenerationConfig
model.generation_config = GenerationConfig.from_dict({'num_beams':1,
                            'max_new_tokens':100,
                            'max_length':200})
model.generation_config.num_beams=1
model.generation_config.max_new_tokens = 100 
model.generation_config.max_length=200
def get_ans(tensor) ->"str":
    s = "".join([vocab_r[i] for i in tensor.tolist()])
    ans = s[s.find('=')+1:s.find('<EOS>')].replace('<BOS>','').replace('<EOS>','')
    return ans
x,y = get_data() 
print('x: '+''.join(x).replace('<BOS>',''))
print('y: '+''.join(y).replace('<EOS>',''))
x: 3481340050+90157504501803=
y: 90160985841853
input_ids = torch.tensor([[vocab[i] for i in x]]) 
out = model.generate(inputs=input_ids)
out 

텐서([[ 1, 5, 6, 10, 3, 5, 6, 12, 12, 7, 12, 13, 11, 12, 3, 7, 9, 7, 12, 6, 7, 12, 3, 10, 12, 5, 14, 11, 12, 3, 8, 12, 11, 10, 7, 10, 6, 3, 10, 7, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 12, 2, 2, 2, 2, 2, 2, 2, 2, 12, 3, 12, 3]])

get_ans(out[0])

'90160985841853'

V. 평가 모델

from tqdm import tqdm 
loop = tqdm(range(1,201))
correct = 0
for i in loop:
    x,y = get_data() 
    input_ids = torch.tensor([[vocab[i] for i in x]]) 
    out = model.generate(inputs=input_ids)
    pred = get_ans(out[0])
    gt = ''.join(y).replace('<EOS>','')
    if pred==gt:
        correct+=1
    loop.set_postfix(acc = correct/i)
    
print("acc=",correct/len(loop))

acc= 0.99

좋아, 우리 테스트는 99% 정확합니다!

추천

출처blog.csdn.net/sinat_37574187/article/details/132197855