[NLP] Alpaca-Lora を使用したラマ モデルに基づく微調整チュートリアル

Stanford Alpaca は LLaMA モデル全体で微調整されます。つまり、事前トレーニングされたモデルのすべてのパラメーターが微調整されます (完全微調整)。ただし、この方法でも依然としてハードウェアのコストが高く、トレーニング効率が低くなります。

[NLP] 大規模な言語モデルを理解するための効率的な微調整 (PEFT)

したがって、Alpaca-Lora は Lora テクノロジーを使用して、元のモデルの LLaMA パラメーターを凍結しながらモデルにネットワーク層を追加し、これらの新しいネットワーク層パラメーターのみをトレーニングします。これらの新しいパラメータの数が少ないため、微調整のコストが大幅に削減されるだけでなく、完全な微調整と同様の効果が得られます。

LoRA の原理は実際には複雑ではありません。その中心的なアイデアは、元の事前トレーニングされた言語モデルの隣にバイパスを追加し、次元削減とその後の次元増加操作を実行して、いわゆる固有ランク (事前トレーニング) をシミュレートすることです。一般化のプロセスは、実際には、さまざまなタスクの共通の低次元固有 (低次元固有) 部分空間内の非常に少数の自由パラメーターを最適化することです。トレーニング中、事前トレーニングされた言語モデルのパラメーターは固定され、次元削減行列 A と次元強化行列 B のみがトレーニングされます。モデルの入力次元と出力次元は変更されませんが、出力時に BA のパラメータと事前学習済み言語モデルが重ね合わされます。A をランダムなガウス分布で初期化し、B を 0 行列で初期化します。これにより、トレーニングの開始時に、新しく追加されたパス BA=0 がモデルの結果に影響を与えないことが保証されます。

推論中は、左側と右側の部分の結果を加算するだけです。 h=Wx+BAx=(W+BA)x なので、学習後の行列積 BA と元の重み行列 W を新しい値として加算するだけです。元の事前トレーニング済み言語モデルの W を重みパラメーターに置き換えるだけで十分であり、追加のコンピューティング リソースは追加されません。

LoRA の最大の利点は、高速でメモリ使用量が少ないため、消費者向けハードウェアで実行できることです。

ワン環境構築

基本的な環境構成は以下のとおりです。

  • オペレーティングシステム:  CentOS 7
  • CPU: 単一ノードには 1TB メモリを搭載した Intel CPU が搭載されており、物理 CPU の数は 64 個、CPU あたりのコア数は 16 個です。
  • GPU:  4 カード A100 80GB GPU
  • Docker イメージ: pytorch:1.13.0-cuda11.6-cudnn8-devel

Alpaca-LoRA プロジェクトでは、著者らは、安価で効率的な微調整のために、Hugging Face の PEFT を使用したと述べました。PEFT はライブラリ (Prefix Tuning、P-Tuning、Prompt Tuning に加えて LoRA もサポートされているテクノロジの 1 つ) であり、これを使用すると、さまざまな Transformer ベースの言語モデルを使用して効率的な微調整を行うことができます。以下のPEFTをインストールします。

 
 
#安装peft
git clone https://github.com/huggingface/peft.git
cd peft/
pip install .

#安装bitsandbytes。
git clone [email protected]:TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=116 make cuda11x
python setup.py install
bitsandbytes のインストール時に次のエラーが発生した場合: 
/usr/bin/ld: -lcudart が見つかりません

次のコマンドを実行してください

cd /usr/lib
ln -s /usr/local/cuda/lib64/libcudart.so libcudart.so

#下载alpaca-lora
git clone [email protected]:tloen/alpaca-lora.git
cd alpaca-lora
pip install -r requirements.txt

requirements.txtファイルの具体的な内容は次のとおりです。

accelerate
appdirs
loralib
bitsandbytes
black
black[jupyter]
datasets
fire
git+https://github.com/huggingface/peft.git
transformers>=4.28.0
sentencepiece
gradio

モデル形式の変換

LLaMA オリジナルのウェイト ファイルを、Transformers ライブラリに対応するモデル ファイル形式に変換します。変換されたモデルは、次のように Hugging Face から直接ダウンロードできます。

ダウンロード方法は、【NLP】ハグフェイスモデル・データファイルのダウンロード方法をご参照ください。

decapoda-research/llama-7b-hf · 抱き合う顔

decapoda-research/llama-13b-hf · 抱きしめる顔

モデルの微調整

python finetune.py \
    --base_model '/disk1/llama-13b' \
    --data_path './alpaca_data_cleaned_archive.json' \
    --output_dir './lora-alpaca' \
    --batch_size 128 \
    --micro_batch_size 8 \
    --num_epochs 1


torchrun --nproc_per_node=4 --master_port=29000 finetune.py \
    --base_model '/disk1/llama-13b' \
    --data_path './alpaca_data_cleaned_archive.json' \
    --output_dir './lora-alpaca' \
    --batch_size 128 \
    --micro_batch_size 8 \
    --num_epochs 1
Training Alpaca-LoRA model with params:
base_model: /disk1/llama-13b
data_path: ./alpaca_data_cleaned_archive.json
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 1
learning_rate: 0.0003
cutoff_len: 256
val_set_size: 2000
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'v_proj']
train_on_inputs: True
add_eos_token: False
group_by_length: False
wandb_project: 
wandb_run_name: 
wandb_watch: 
wandb_log_model: 
resume_from_checkpoint: False
prompt template: alpaca

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00,  1.06s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00,  1.06s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00,  1.06s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00,  1.06s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   3%|███▊                                                                                                                                          | 1330/49759 [00:01<00:39, 1216.23 examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   0%|                                                                                                                                                           | 0/49759 [00:00<?, ? examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   1%|▊                                                                                                                                              | 272/49759 [00:00<00:36, 1350.21 examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1294.31 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1284.04 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1283.95 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1221.03 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49759/49759 [00:39<00:00, 1274.42 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1285.16 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1281.27 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1290.31 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
  0%|                                                                                                                                                                         | 0/388 [00:00<?, ?it/s]/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
{'loss': 2.249, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.03}                                                                                                                               
{'loss': 2.1927, 'learning_rate': 5.6999999999999996e-05, 'epoch': 0.05}                                                                                                                              
{'loss': 2.0813, 'learning_rate': 7.8e-05, 'epoch': 0.08}                                                                                                                                             
{'loss': 1.7206, 'learning_rate': 0.00010799999999999998, 'epoch': 0.1}                                                                                                                               
 11%|████████████████▋                                                                                                                               11%|███████████▋                                                                                                | 42/388 [10:50<1:27:2

上図は4枚のカードの出力結果で、ビデオメモリの使用量は以下の通りです。 

-------------------------------+----------------------+----------------------+
|   0  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   60C    P0   322W / 400W |  36944MiB / 81920MiB |     89%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:4B:00.0 Off |                    0 |
| N/A   61C    P0   321W / 400W |  34204MiB / 81920MiB |     97%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:89:00.0 Off |                    0 |
| N/A   62C    P0   349W / 400W |  34200MiB / 81920MiB |     98%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:8E:00.0 Off |                    0 |
| N/A   63C    P0   261W / 400W |  33882MiB / 81920MiB |     89%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

HuggingFace 形式にエクスポート:

ダウンロード可能:  Angainor/alpaca-lora-13b    lora_weights for Hugging Face

export_hf_checkpoint.pyファイルを変更します。

import os

import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402

BASE_MODEL = os.environ.get("BASE_MODEL", "/disk1/llama-13b")
LORA_MODEL = os.environ.get("LORA_MODEL", "./alpaca-lora-13b")
HF_CHECKPOINT = os.environ.get("HF_CHECKPOINT", "./hf_ckpt")

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

lora_weight = lora_model.base_model.model.model.layers[
    0
].self_attn.q_proj.weight

assert torch.allclose(first_weight_old, first_weight)

# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)

lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

LlamaForCausalLM.save_pretrained(
    base_model, HF_CHECKPOINT, state_dict=deloreanized_sd, max_shard_size="400MB"
)

pythonエクスポート_hf_checkpoint.py

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:26<00:00,  1.56it/s]

モデル出力ファイルを表示します。

hf_ckpt/
├── config.json
├── generation_config.json
├── pytorch_model-00001-of-00082.bin
├── pytorch_model-00002-of-00082.bin
├── pytorch_model-00003-of-00082.bin
├── pytorch_model-00004-of-00082.bin
├── pytorch_model-00005-of-00082.bin
├── pytorch_model-00006-of-00082.bin
├── pytorch_model-00007-of-00082.bin
├── pytorch_model-00008-of-00082.bin
├── pytorch_model-00009-of-00082.bin
├── pytorch_model-00010-of-00082.bin
├── pytorch_model-00011-of-00082.bin
├── pytorch_model-00012-of-00082.bin
├── pytorch_model-00013-of-00082.bin
├── pytorch_model-00014-of-00082.bin
├── pytorch_model-00015-of-00082.bin
├── pytorch_model-00016-of-00082.bin
├── pytorch_model-00017-of-00082.bin
├── pytorch_model-00018-of-00082.bin
├── pytorch_model-00019-of-00082.bin
├── pytorch_model-00020-of-00082.bin
├── pytorch_model-00021-of-00082.bin
├── pytorch_model-00022-of-00082.bin
├── pytorch_model-00023-of-00082.bin
├── pytorch_model-00024-of-00082.bin
├── pytorch_model-00025-of-00082.bin
├── pytorch_model-00026-of-00082.bin
├── pytorch_model-00027-of-00082.bin
├── pytorch_model-00028-of-00082.bin
├── pytorch_model-00029-of-00082.bin
├── pytorch_model-00030-of-00082.bin
├── pytorch_model-00031-of-00082.bin
├── pytorch_model-00032-of-00082.bin
├── pytorch_model-00033-of-00082.bin
├── pytorch_model-00034-of-00082.bin
├── pytorch_model-00035-of-00082.bin
├── pytorch_model-00036-of-00082.bin
├── pytorch_model-00037-of-00082.bin
├── pytorch_model-00038-of-00082.bin
├── pytorch_model-00039-of-00082.bin
├── pytorch_model-00040-of-00082.bin
├── pytorch_model-00041-of-00082.bin
├── pytorch_model-00042-of-00082.bin
├── pytorch_model-00043-of-00082.bin
├── pytorch_model-00044-of-00082.bin
├── pytorch_model-00045-of-00082.bin
├── pytorch_model-00046-of-00082.bin
├── pytorch_model-00047-of-00082.bin
├── pytorch_model-00048-of-00082.bin
├── pytorch_model-00049-of-00082.bin
├── pytorch_model-00050-of-00082.bin
├── pytorch_model-00051-of-00082.bin
├── pytorch_model-00052-of-00082.bin
├── pytorch_model-00053-of-00082.bin
├── pytorch_model-00054-of-00082.bin
├── pytorch_model-00055-of-00082.bin
├── pytorch_model-00056-of-00082.bin
├── pytorch_model-00057-of-00082.bin
├── pytorch_model-00058-of-00082.bin
├── pytorch_model-00059-of-00082.bin
├── pytorch_model-00060-of-00082.bin
├── pytorch_model-00061-of-00082.bin
├── pytorch_model-00062-of-00082.bin
├── pytorch_model-00063-of-00082.bin
├── pytorch_model-00064-of-00082.bin
├── pytorch_model-00065-of-00082.bin
├── pytorch_model-00066-of-00082.bin
├── pytorch_model-00067-of-00082.bin
├── pytorch_model-00068-of-00082.bin
├── pytorch_model-00069-of-00082.bin
├── pytorch_model-00070-of-00082.bin
├── pytorch_model-00071-of-00082.bin
├── pytorch_model-00072-of-00082.bin
├── pytorch_model-00073-of-00082.bin
├── pytorch_model-00074-of-00082.bin
├── pytorch_model-00075-of-00082.bin
├── pytorch_model-00076-of-00082.bin
├── pytorch_model-00077-of-00082.bin
├── pytorch_model-00078-of-00082.bin
├── pytorch_model-00079-of-00082.bin
├── pytorch_model-00080-of-00082.bin
├── pytorch_model-00081-of-00082.bin
├── pytorch_model-00082-of-00082.bin
└── pytorch_model.bin.index.json

0 directories, 85 files

PyTorch state_dicts としてエクスポートします

export_state_dict_checkpoint.pyファイルを変更します。

参考文献

おすすめ

転載: blog.csdn.net/zwqjoy/article/details/131920540