LLM - Model、Data、Training、Generate Agruments 超参解析

目录

一.引言

二.常用参数

◆ ModelArguments

◆ DataArguments

◆ TrainingArguments

◆ GeneratingArguments

三.代码实现

◆ Python 代码

◆ Shell 代码

四.总结


一.引言

LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。

ModelArguments - 模型参数

DataArguments - 数据集参数

TrainingArguments - 训练参数

GeneratingArguments - 生成参数

二.常用参数

◆ ModelArguments

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:

参数名称 默认 类型 含义
model_name_or_path None str 模型地址或名称
cache_dir None str 缓存地址
use_fast_tokenizer False bool 使用快速 tokenizer
padding_side left str 模型 pad 选择
quantization_bit None int 量化 bit 选择
compute_type None torch.dtype 模型参数类型
checkpoint_dir None str 微调参数地址
mode None str reward、lora
plot_loss False bool 打印训练 Loss

◆ DataArguments

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:

参数名称 默认 类型 含义
data_path None str 数据集地址
process_num None int 并行处理
max_source_length 512 int source 最大长度
max_target_length 512 int target 最大长度
max_samples None int 最大样本数
ignore_pad_token None int loss 计算是否忽略
prompt_template None str 样本生成 prompt 模板

◆ TrainingArguments

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
    output_dir: str = field(default="")

TrainingArguments 主要存储模型微调,训练相关的参数:

参数名称 默认 类型 含义
finetuning_type lora str 微调类型
lora_target q_proj,v_proj str 微调 Layer
lora_rank 8 int lora 降维维度
lora_alpha 32.0 float lora 微调比例因子
lora_dropout 0.1 float dropout 比例
num_hidden_layers 32 int Decode 数量
num_layer_trainable 3 int freeze layer 数量
name_module_trainable mlp str freeze 训练层选择
output_dir None str 模型输出地址

◆ GeneratingArguments

@dataclass
class GeneratingArguments:
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    )

GeneratingArguments 主要负责 model generate 生成的配置:

参数名称 默认 类型 含义
do_sample True bool 采样或贪心
temperature 0.95 float 调整下一个 token 的概率
top_p 0.7 float token 概率 top 区间
top_k 50 int token 词库数量
num_beams 1 int beam search 数量
max_length None int 最大生成 token 数
max_new_tokens 512 int 最多新 toekn 生成数
repatition_penalty 1.0 float 重复惩罚
length_penalty 1.0 float 长度惩罚

之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本

三.代码实现

◆ Python 代码

from typing import Optional
from dataclasses import dataclass, field
import transformers


...

    添加上述的 Argument Class

...


if __name__ == '__main__':
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
    model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()

    print(model_args)
    print(data_args)
    print(training_args)
    print(generate_args)

两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。

Shell 代码

#!/bin/bash

python GetConfigByArgs.py \
    --report_to "none" \
    --data_path "data/belle_chat_ramdon_10k.json" \
    --model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
    --output_dir "output" \
    --model_max_length 512 \
    --num_train_epochs 4 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy epoch \
    --learning_rate 2e-5 \
    --lr_scheduler_type constant \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --logging_steps 1 \
    --gradient_checkpointing True \
    --deepspeed ds_config.json \
    --bf16 False \
    --tf32 False

通过 -- 传递我们需要的参数即可。

四.总结

这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。

猜你喜欢

转载自blog.csdn.net/BIT_666/article/details/132755841