LoRA Fine-tuning for ChatGLM-6B

ChatGLM-6B is a Chinese-English bilingual dialogue language model based on GLM (General Language Model). It has only 6.2 billion parameters, and after INT4 quantization it needs as little as 6GB of GPU memory, so it can be deployed on consumer-grade graphics cards. After using the model for a while, we found that its dialogue ability is indeed very good, which makes it well worth fine-tuning.

Note:

All technical information in this article is based on the historical revision 096f3de6b4959ce38bef7bb05f3129c931a3084e of THUDM/chatglm-6b.

Set up the environment

Install PyTorch:

pip install torch torchvision torchaudio

Following the official ChatGLM-6B guide, install the software dependencies:

pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels  

To use LoRA, you also need to install peft:

pip install peft

Load the model and Tokenizer

from transformers import AutoTokenizer, AutoModel

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
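
Optionally, you can sanity-check that the freshly loaded base model responds before adding LoRA. The following is a minimal sketch, assuming a CUDA GPU with enough memory for the FP16 weights (roughly 13GB); if you run it, it is cleanest to reload the model afterwards so that fine-tuning starts from the original float32 weights:

# Optional sanity check of the base model before any fine-tuning.
model = model.half().cuda()
response, history = model.chat(tokenizer, "你好", history=[])
print(response)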

Analyze the Model Structure

After the model is loaded, we can print the model and the tokenizer to establish a basic understanding of them.

First print model:

print(model)

The following results are obtained:

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(150528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)

A quick analysis of this structure gives us at least the following information:

  • The model uses the Transformer architecture, so LoRA can be used for fine-tuning
  • From the word embedding layer, we can see that the vocabulary size is 150528
  • A module that LoRA can target is query_key_value (you can also enumerate the candidates programmatically, as sketched below)
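
If you prefer to enumerate the candidate modules programmatically instead of reading the printout, a small sketch like the following works:

import torch.nn as nn

# List every Linear layer by its qualified name; these are the usual LoRA injection candidates.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(name)
# Besides lm_head, names ending in query_key_value, dense, dense_h_to_4h and dense_4h_to_h
# show up; this article only targets query_key_value.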

Then print tokenizer:

print(tokenizer)

The following result is obtained (split across multiple lines for readability):

ChatGLMTokenizer(
	name_or_path='THUDM/chatglm-6b', 
	vocab_size=150344, 
	model_max_length=2048, 
	is_fast=False, 
	padding_side='left', 
	truncation_side='right', 
	special_tokens={
		'bos_token': '<sop>', 
		'eos_token': '</s>', 
		'unk_token': '<unk>', 
		'pad_token': '<pad>', 
		'mask_token': '[MASK]'
	}
)

Here are a few points to focus on:

  • The vocabulary size vocab_size is 150344
  • It is not a fast tokenizer (the value of is_fast is False)
  • The special tokens include bos, eos, pad and mask

Why is the vocabulary size 150528 in the model but 150344 in the tokenizer? Readers can take this question to the source code of the model project and see if they can find the answer.
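
If you want to poke at this question interactively before reading the source, the following quick sketch (using standard transformers attributes) prints the numbers involved:

print(model.config.vocab_size)   # 150528, the size of the embedding and lm_head
print(tokenizer.vocab_size)      # 150344, as reported by the tokenizer
print(len(tokenizer))            # tokenizer vocabulary plus any added tokens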

Configure LoRA

With the peft library, we can easily inject LoRA into the model.

from peft import LoraConfig, get_peft_model, TaskType

def load_lora_config(model):
	config = LoraConfig(
	    task_type=TaskType.CAUSAL_LM, 
	    inference_mode=False,
	    r=8, 
	    lora_alpha=32, 
	    lora_dropout=0.1,
	    target_modules=["query_key_value"]
	)
	return get_peft_model(model, config)

model = load_lora_config(model)

Print the number of trainable parameters:

model.print_trainable_parameters()

The following results are obtained:

trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348

As you can see, the total number of parameters is 6,258,876,416, while the number of trainable parameters is 3,670,016, about 0.0586% of the total. With only a few million trainable parameters, this is quite friendly to train! Another thing to note: ChatGLM-6B is a causal language model, so the task type chosen here is CAUSAL_LM.
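
The 3,670,016 figure can be reproduced by hand: LoRA attaches a pair of low-rank matrices A (in_features × r) and B (r × out_features) to the query_key_value projection of each of the 28 GLMBlock layers:

r = 8
in_features, out_features = 4096, 12288   # query_key_value: 4096 -> 12288
num_layers = 28
print(num_layers * (in_features * r + r * out_features))   # 3670016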

Build the Dataset

Define Constants

Before building the dataset, we define several special token constants:

bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]

Print out these values:

print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)

The following results are obtained:

bos =  150004
eop =  150005
pad =  20003
mask =  150000
gmask =  150001

We can also replace the dynamically computed values directly with these constants:

bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001

In addition to the token constants defined above, we also need to define the device used for training, as well as the maximum input and output lengths:

device = "cuda"
max_src_length = 200
max_dst_length = 500

Developers can choose these maximum lengths based on their GPU memory and the characteristics of the dataset to be processed.

Test the Tokenizer's Encoding and Decoding

We can do a simple test first:

text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))

The output is:

[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]

From this result, we can see that the bare encoding of "AI探险家" ("AI explorer") is [26738, 98715, 83920]. Why is that? We can decode each value to see its output:

print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))

The output is:

AI
探险
家

Observing this result, the reader should be able to form a basic picture of the vocabulary. If you are interested, you can also encode "A", "I", "探险" and "家" individually to see what the results are.
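
For example, a quick loop such as the following (the exact ids depend on the vocabulary) shows how smaller pieces are tokenized:

for piece in ["A", "I", "探险", "家"]:
    print(piece, tokenizer.encode(piece, add_special_tokens=False))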

In addition, when add_special_tokens=True, the encoding result has 150001 and 150004 appended at the end, i.e. gmask and bos. Note that our training data should be constructed according to the following token layout:

[token, ..., token, gmask, bos, token, ... token, eop]

Therefore, the first half of the text (the prompt) can be encoded with add_special_tokens=True, the second half (the answer) with add_special_tokens=False, and finally an eop token is appended.
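
As a concrete sketch, using the question/answer pair that also appears later in this article (a hand-built sample for illustration only; the helper functions below, which also add the prompt template, are what the training pipeline actually uses):

q_ids = tokenizer.encode("AI探险家帅不帅?", add_special_tokens=True)   # ends with gmask, bos
a_ids = tokenizer.encode("非常帅!", add_special_tokens=False)
sample_ids = q_ids + a_ids + [eop]
print(sample_ids)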

Define Prompt

Our fine-tuning task is question answering (QA for short), so a simple prompt looks like this:

PROMPT_PATTERN = "问:{}\n答: "

The {} is filled with the question text from the QA training set. With limited GPU memory, unrestricted long text can easily trigger CUDA out of memory errors. Given an upper limit on the length of the encoded sequence, there are several ways to deal with long text:

  • Truncate the excess tokens at the end
  • Truncate the excess tokens at the beginning
  • Drop the training sample

Each method has its own advantages and disadvantages, and developers can choose one according to the characteristics of their own data. Of course, if your GPU memory is large enough, you may not need to handle it at all. This article uses the first method.
To avoid truncating the "\n答: " part of PROMPT_PATTERN, we split PROMPT_PATTERN into two parts:

PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "

Based on this Prompt template, we define the following three helper methods:

import torch

def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN


def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep, 
        add_special_tokens = True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt, 
        max_length = max_src_length - (sep_len - special_tokens_num),
        truncation = True,
        add_special_tokens = False
    )

    return prompt_ids + sep_ids


def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer, 
        max_length = max_dst_length,
        truncation = True,
        add_special_tokens = False
    )

    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop] 
    
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels

Two points worth noting:

  • As can be seen from the implementation of create_prompt_ids, when we encode the separator SEP_PATTERN, the two special tokens mentioned above are added automatically.
  • In the implementation of create_inputs_and_labels, the positions of labels that should not contribute to the loss are filled with the value -100. This is because ChatGLMForConditionalGeneration internally uses torch.nn.CrossEntropyLoss to compute the loss, and the ignore_index argument of that function defaults to -100. This lets us compute the loss without considering the non-label positions (see the short demonstration after this list).
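
A tiny standalone demonstration of this ignore_index behaviour (not part of the training pipeline):

import torch

loss_fn = torch.nn.CrossEntropyLoss()       # ignore_index defaults to -100
logits = torch.randn(4, 10)                 # 4 positions, 10 classes
targets = torch.tensor([-100, -100, 3, 7])  # the first two positions are ignored
print(loss_fn(logits, targets))             # loss is averaged over the last two positions only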

Build Attention Mask and Position IDs

def get_attention_mask(tokenizer, input_ids, device):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    attention_mask = torch.ones((seq_len, seq_len), device=device)
    attention_mask.tril_()
    attention_mask[..., :context_len] = 1
    attention_mask.unsqueeze_(0)
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask


def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)

    mask_token = mask if mask in seq else gmask
    use_gmask = False if mask in seq else gmask

    mask_position = seq.index(mask_token)

    if position_encoding_2d:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
        block_position_ids = torch.cat((
            torch.zeros(context_len, dtype=torch.long, device=device),
            torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
    
    return position_ids

In this general implementation, we distinguish between the mask and gmask cases, and also handle both the 2D and 1D cases via position_encoding_2d. The QA task in this article uses gmask and position_encoding_2d=True.

We can construct the following question and answer to verify the output of these functions:

test_data = {
    "question": "AI探险家帅不帅?",
    "answer": "非常帅!"
}

inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)

print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())

Output (formatted for readability):

inputs: 
 [20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]

labels: 
 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]

position_ids: 
 [
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5]
 ]

attention_mask: 
 [[
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]

Comparing this with the illustrations in the GLM paper, the output is basically in line with expectations.

Create the Dataset

We first define the training data with the following format:

train_data = [
    {"question": "问题1", "answer": "答案1"},
    {"question": "问题2", "answer": "答案2"},
]

With the format defined, we first create a QADataset class, as follows:

from torch.utils.data import Dataset

class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
 

    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer, 
            device=device,
            **item_data
        )
        
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
        

    def __len__(self):
        return len(self.data)

Then create a Data Collator:

def collate_fn(batch):
    input_ids = []
    attention_mask = []
    labels = []
    position_ids = []
    
    for obj in batch:
        input_ids.append(obj['input_ids'])
        labels.append(obj['labels'])
        attention_mask.append(obj['attention_mask'])
        position_ids.append(obj['position_ids'])
        
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
        'position_ids': torch.stack(position_ids)
    }
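
Because collate_fn simply stacks the tensors, every sample in a batch must have the same sequence length; with variable-length QA pairs this effectively means a batch size of 1, which matches the per_device_train_batch_size=1 used below. A quick sanity-check sketch:

from torch.utils.data import DataLoader

loader = DataLoader(QADataset(train_data, tokenizer=tokenizer), batch_size=1, collate_fn=collate_fn)
batch = next(iter(loader))
print({k: v.shape for k, v in batch.items()})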

Start Training

from transformers import TrainingArguments, Trainer
model.to(device)

training_args = TrainingArguments(
    "output",
    fp16 =True,
    save_steps = 500,
    save_total_limit = 3,
    gradient_accumulation_steps=1,
    per_device_train_batch_size = 1,
    learning_rate = 1e-4,
    max_steps=1500,
    logging_steps=50,
    remove_unused_columns=False,
    seed=0,
    data_seed=0,
    group_by_length=False,
    dataloader_pin_memory=False
)

class ModifiedTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            labels=inputs["labels"],
        ).loss


train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)

trainer.train()

Predict

response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)

Save the Trained Model

import os

def save_tuned_parameters(model, path):
    saved_params = {
        k: v.to(device)
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)

save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))

Reload the Trained Model

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

model = load_lora_config(model)
model.load_state_dict(torch.load(f"/path/to/output/chatglm-6b-lora.pt"), strict=False)

model.half().cuda().eval()
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)
