LoRA Fine-tuning for ChatGLM-6B
ChatGLM-6B is a Chinese-English bilingual dialogue language model based on GLM (General Language Model). It has only 6.2 billion parameters, and after INT4 quantization it needs as little as 6GB of video memory, so it can be deployed on consumer-grade graphics cards. After using the model for a while, we found that its dialogue ability is indeed very good, which makes fine-tuning on top of it well worth doing.
Note: All technical information provided in this article is based on the historical revision 096f3de6b4959ce38bef7bb05f3129c931a3084e of THUDM/chatglm-6b.
Source address:
Set up the environment
Install the PyTorch environment:
pip install torch torchvision torchaudio
Following the official guide of ChatGLM-6B, install the required dependencies:
pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
To use LoRA, you also need to install peft:
pip install peft
Load the model and Tokenizer
import torch  # used by the helper functions later in this article
from transformers import AutoTokenizer, AutoModel
checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
Analyze the model structure
After the model is loaded, we can print model and tokenizer to establish a basic understanding of them.
First, print model:
print(model)
The following results are obtained:
ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(150528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)
A quick analysis of this structure yields at least the following information:
- The model uses the Transformer structure, so LoRA can be used for fine-tuning
- As can be seen from the word embedding layer, the vocabulary size is 150528
- The module that LoRA can operate on is query_key_value
Then print tokenizer:
print(tokenizer)
The following results are obtained (for readability, the result has been split across lines):
ChatGLMTokenizer(
  name_or_path='THUDM/chatglm-6b',
  vocab_size=150344,
  model_max_length=2048,
  is_fast=False,
  padding_side='left',
  truncation_side='right',
  special_tokens={
    'bos_token': '<sop>',
    'eos_token': '</s>',
    'unk_token': '<unk>',
    'pad_token': '<pad>',
    'mask_token': '[MASK]'
  }
)
Here are a few points to focus on:
- The vocabulary size vocab_size is 150344
- It is not a fast Tokenizer (the value of is_fast is False)
- Special tokens include bos, eos, pad and mask
Why is the vocabulary size 150528 in model, but 150344 in tokenizer? Readers can take this question to the source code of the model project and see whether they can find the answer.
Configure LoRA
With the peft library, we can easily inject LoRA into the model.
from peft import LoraConfig, get_peft_model, TaskType

def load_lora_config(model):
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query_key_value"]
    )
    return get_peft_model(model, config)

model = load_lora_config(model)
Print the number of trainable parameters:
model.print_trainable_parameters()
The following results are obtained:
trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348
It can be seen that the total number of parameters is 6,258,876,416 and the number of trainable parameters is 3,670,016, about 0.0586% of the total. The trainable parameters are only on the order of millions, which is quite friendly! Another thing to note is that ChatGLM-6B is a causal language model (Causal Language Model), so the task type we choose here is CAUSAL_LM.
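The trainable parameter count can be reproduced by hand. The sketch below assumes the standard LoRA factorization, where each targeted Linear layer gains two matrices of shape r × in_features and out_features × r; the helper lora_param_count is a hypothetical name introduced for illustration and is not part of peft.

```python
# Sanity check of the trainable-parameter count reported by peft.
# Assumption: each targeted Linear layer gets a LoRA pair
# A (r x in_features) and B (out_features x r), with no extra bias terms.

def lora_param_count(num_layers, in_features, out_features, r):
    """Number of parameters added by LoRA across all targeted layers."""
    per_layer = r * in_features + out_features * r
    return num_layers * per_layer

# Values read from the printed model: 28 GLMBlocks, each with
# query_key_value = Linear(4096 -> 12288); r = 8 from LoraConfig.
trainable = lora_param_count(num_layers=28, in_features=4096,
                             out_features=12288, r=8)
total = 6258876416  # "all params" as printed above

print(trainable)                          # 3670016
print(round(100 * trainable / total, 4))  # 0.0586
```

The result matches the printed value, which suggests that only the 28 query_key_value layers carry LoRA weights in this configuration.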
Build the dataset
Define constants
Before building, we define several special Token constants:
bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]
Print out these values:
print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)
The following results are obtained:
bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001
We can also replace the dynamically computed values with these constants directly. After the replacement, the definitions become:
bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001
In addition to the Token constants defined above, we also need to define the device used for training, as well as the maximum input length and the maximum output length:
device = "cuda"
max_src_length = 200
max_dst_length = 500
Developers can determine these maximum lengths based on their graphics card performance and the characteristics of the data set to be processed.
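One possible way to choose these limits is to look at the distribution of token lengths in the data set and pick a length that covers most samples. The sketch below is only an illustration: length_at_coverage is a hypothetical helper and the sample lengths are made up.

```python
def length_at_coverage(lengths, coverage_percent=95):
    """Smallest length that covers at least coverage_percent of the samples."""
    if not lengths:
        raise ValueError("lengths must not be empty")
    ordered = sorted(lengths)
    # Integer ceiling of coverage_percent * n / 100, converted to a 0-based index.
    k = (coverage_percent * len(ordered) + 99) // 100 - 1
    return ordered[max(k, 0)]

# Hypothetical token counts; in practice these could come from
# len(tokenizer.encode(text, add_special_tokens=False)) for each sample.
question_lengths = [12, 35, 48, 51, 60, 75, 90, 120, 180, 640]
print(length_at_coverage(question_lengths, 90))   # 180
print(length_at_coverage(question_lengths, 100))  # 640
```

With these made-up numbers, a limit of 180 would leave nine of the ten questions untouched, while covering the single outlier would require more than three times that length.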
Test Tokenizer encoding and decoding
We can do a simple test first:
text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))
The output is:
[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]
From this result, it can be seen that the raw encoding of the text "AI探险家" is [26738, 98715, 83920]. Why is this so? We can decode each value to see the output:
print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))
The output is:
AI
探险
家
Observing this result, the reader should be able to establish a basic understanding of the vocabulary. Interested readers can also encode "A", "I", "探" and "险" separately to see what the encoding results are.
In addition, when add_special_tokens = True, the encoding result has 150001 and 150004 appended at the end, that is, gmask and bos. Please note that our training data should be constructed according to the following encoding requirements:
[token, ..., token, gmask, bos, token, ... token, eop]
Therefore, the first half of the text can be encoded directly with add_special_tokens = True, the second half can be encoded with add_special_tokens = False, and finally an eop is appended.
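The layout above can be sketched with plain Python lists. The token ids for gmask, bos and eop are the constants printed earlier; build_sequence and the answer ids are hypothetical and serve only as an illustration.

```python
# Special token ids printed earlier in this article.
GMASK, BOS, EOP = 150001, 150004, 150005

def build_sequence(prompt_ids, answer_ids):
    """Lay out one sample as [prompt..., gmask, bos, answer..., eop]."""
    return list(prompt_ids) + [GMASK, BOS] + list(answer_ids) + [EOP]

prompt_ids = [26738, 98715, 83920]  # "AI探险家" encoded without special tokens
answer_ids = [101, 102]             # made-up answer ids, for illustration only
sequence = build_sequence(prompt_ids, answer_ids)
print(sequence)
# [26738, 98715, 83920, 150001, 150004, 101, 102, 150005]
```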
Define Prompt
Our Fine-tuning task is a question answering task (QA for short), so a simple Prompt looks like this:
PROMPT_PATTERN = "问:{}\n答: "
Here {} is filled with the question text from the QA training set (问 means "question" and 答 means "answer"). With limited video memory, leaving long text unrestricted easily leads to errors such as CUDA out of memory. Given an upper limit on the length of the encoded array, there are several ways to handle long text:
- Truncate the part of the encoding that exceeds the limit at the end
- Truncate the part of the encoding that exceeds the limit at the beginning
- Drop the training sample
Each method has its own advantages and disadvantages, and developers can choose one according to the characteristics of their own data. Of course, if your video memory is large enough, you may not need to handle this at all. This article uses the first method.
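The three options above can be sketched as a small helper that works on a list of token ids. The function name fit_to_limit and the strategy names are assumptions made for this illustration, not part of any library.

```python
def fit_to_limit(ids, max_length, strategy="truncate_end"):
    """Apply one of the three long-text strategies to a list of token ids.

    Returns the (possibly shortened) list, or None when the sample is dropped.
    """
    if len(ids) <= max_length:
        return list(ids)
    if strategy == "truncate_end":      # keep the beginning, cut the tail
        return list(ids[:max_length])
    if strategy == "truncate_start":    # keep the ending, cut the head
        return list(ids[len(ids) - max_length:])
    if strategy == "drop":              # discard the whole sample
        return None
    raise ValueError(f"unknown strategy: {strategy}")

ids = list(range(10))
print(fit_to_limit(ids, 4, "truncate_end"))    # [0, 1, 2, 3]
print(fit_to_limit(ids, 4, "truncate_start"))  # [6, 7, 8, 9]
print(fit_to_limit(ids, 4, "drop"))            # None
```

Truncating at the end corresponds to calling tokenizer.encode with truncation = True, since this tokenizer reports truncation_side='right'.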
To keep the \n答: part of PROMPT_PATTERN from being truncated, we split PROMPT_PATTERN into two parts:
PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "
Based on this Prompt template, we define the following three helper methods:
def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN

def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    # The separator is encoded with special tokens, so it ends with gmask and bos
    sep_ids = tokenizer.encode(
        sep,
        add_special_tokens = True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    # Only the question text is truncated; the separator is always kept whole
    prompt_ids = tokenizer.encode(
        prompt,
        max_length = max_src_length - (sep_len - special_tokens_num),
        truncation = True,
        add_special_tokens = False
    )
    return prompt_ids + sep_ids

def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer,
        max_length = max_dst_length,
        truncation = True,
        add_special_tokens = False
    )
    inputs = prompt + completion + [eop]
    # Prompt positions are set to -100 so that they are excluded from the loss
    labels = [-100] * len(prompt) + completion + [eop]
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels
Two points worth noting:
- From the implementation of create_prompt_ids, it can be seen that when we encode the separator SEP_PATTERN, the two special Tokens mentioned above are added automatically.
- In the implementation of create_inputs_and_labels, we mark the part of labels that does not need to be processed with the value -100. This is because ChatGLMForConditionalGeneration internally uses torch.nn.CrossEntropyLoss when calculating the loss function, and the ignore_index argument of that class has a default value of -100. This allows us to calculate the loss function without considering the non-label part.
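To make the effect of -100 concrete, here is a minimal pure-Python sketch of a cross-entropy that skips ignored positions. It is an illustration written from scratch, not the PyTorch implementation; masked_cross_entropy is a hypothetical name, and returning 0.0 when every position is ignored is a choice made only for this sketch.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: list of per-position score lists; labels: list of class ids.
    """
    losses = []
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # prompt positions contribute nothing to the loss
        # log-sum-exp with the maximum subtracted for numerical stability
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
        losses.append(log_norm - scores[label])
    # Sketch-only choice: report 0.0 when there is nothing to average
    return sum(losses) / len(losses) if losses else 0.0

logits = [[9.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
masked = masked_cross_entropy(logits, [IGNORE_INDEX, 1, 2])
only_answer = masked_cross_entropy(logits[1:], [1, 2])
print(abs(masked - only_answer) < 1e-12)  # True
```

The first position carries the label -100, so its scores have no influence on the result: the loss equals the one computed from the answer positions alone.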
Build Attention Mask and Position IDs
def get_attention_mask(tokenizer, input_ids, device):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    attention_mask = torch.ones((seq_len, seq_len), device=device)
    # Causal (lower-triangular) mask, then make the whole context visible to every position
    attention_mask.tril_()
    attention_mask[..., :context_len] = 1
    attention_mask.unsqueeze_(0)
    # True marks positions that must not be attended to
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask

def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    mask_token = mask if mask in seq else gmask
    # gmask is used whenever the sequence contains no [MASK] token
    use_gmask = mask not in seq
    mask_position = seq.index(mask_token)
    if position_encoding_2d:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
        block_position_ids = torch.cat((
            torch.zeros(context_len, dtype=torch.long, device=device),
            torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
    return position_ids
In this general implementation, we distinguish between the mask and gmask cases, and also handle each case according to position_encoding_2d. The QA task of this article uses gmask with position_encoding_2d = True.
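The same logic can be traced on a tiny sequence without torch. The sketch below assumes the gmask case with 2D position encoding; prefix_attention_mask and position_ids_2d are hypothetical names that mirror the two functions above using plain lists.

```python
def prefix_attention_mask(seq_len, context_len):
    """True marks a position that must NOT be attended to.

    Context tokens (columns < context_len) are visible to every row;
    the remaining tokens are visible only causally.
    """
    return [
        [not (col < context_len or col <= row) for col in range(seq_len)]
        for row in range(seq_len)
    ]

def position_ids_2d(seq_len, context_len):
    """2D position ids for the gmask case, following get_position_ids."""
    positions = list(range(seq_len))
    block_positions = [0] * context_len + list(range(1, seq_len - context_len + 1))
    return [positions, block_positions]

# Toy sequence: 3 context tokens (the last one is gmask), then bos and 2 more tokens.
toy_mask = prefix_attention_mask(seq_len=6, context_len=3)
print(toy_mask[0])  # [False, False, False, True, True, True]
print(toy_mask[4])  # [False, False, False, False, False, True]
print(position_ids_2d(seq_len=6, context_len=3))
# [[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 2, 3]]
```

Here context_len plays the role of seq.index(bos): every row can see the whole context, and the block positions start counting from 1 at bos.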
We can construct the following question and answer to verify the output of these functions:
test_data = {
    "question": "AI探险家帅不帅?",
    "answer": "非常帅!"
}
inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)
print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())
Output (formatted for readability):
inputs:
[20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]
labels:
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]
position_ids:
[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6]
]
attention_mask:
[[
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]
Compared with what the paper describes, these results are basically in line with expectations.
Create the dataset
We first define the training data with the following format:
train_data = [
    # "问题1" / "答案1" are placeholders meaning "question 1" / "answer 1"
    {"question": "问题1", "answer": "答案1"},
    {"question": "问题2", "answer": "答案2"},
]
After defining the format, we create a QADataset class as follows:
from torch.utils.data import Dataset
class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer,
            device=device,
            **item_data
        )
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def __len__(self):
        return len(self.data)
Then create a Data Collator:
def collate_fn(batch):
    input_ids = []
    attention_mask = []
    labels = []
    position_ids = []
    for obj in batch:
        input_ids.append(obj['input_ids'])
        labels.append(obj['labels'])
        attention_mask.append(obj['attention_mask'])
        position_ids.append(obj['position_ids'])
    # torch.stack requires every sample in the batch to have the same length;
    # this holds here because the batch size used below is 1
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
        'position_ids': torch.stack(position_ids)
    }
Start training
from transformers import TrainingArguments, Trainer
model.to(device)
training_args = TrainingArguments(
    "output",
    fp16=True,
    save_steps=500,
    save_total_limit=3,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    max_steps=1500,
    logging_steps=50,
    remove_unused_columns=False,
    seed=0,
    data_seed=0,
    group_by_length=False,
    dataloader_pin_memory=False
)

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            labels=inputs["labels"],
        ).loss

train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)
trainer.train()
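As a rough aid for choosing max_steps, the number of training samples consumed by a run can be estimated from the arguments above. This sketch assumes a single device and that every step processes a full batch; samples_seen and approx_epochs are hypothetical helpers, and the data set size of 300 is made up.

```python
def samples_seen(max_steps, per_device_batch_size, grad_accum_steps, num_devices=1):
    """Training samples consumed over the whole run."""
    return max_steps * per_device_batch_size * grad_accum_steps * num_devices

def approx_epochs(total_samples_seen, dataset_size):
    """How many passes over the data set this corresponds to."""
    return total_samples_seen / dataset_size

seen = samples_seen(max_steps=1500, per_device_batch_size=1, grad_accum_steps=1)
print(seen)                      # 1500
print(approx_epochs(seen, 300))  # 5.0 (300 is a hypothetical data set size)
```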
Predict
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)
Save the trained model
import os
def save_tuned_parameters(model, path):
    # Save only the parameters that were trained (the LoRA weights)
    saved_params = {
        k: v.to(device)
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)
save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))
Reload the trained model
checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
model = load_lora_config(model)
model.load_state_dict(torch.load("/path/to/output/chatglm-6b-lora.pt"), strict=False)
model.half().cuda().eval()
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)