ChatGLM2-6B源码解析./ptuning/main.py (一)

from datasets import load_dataset
import jieba 
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer

from arguments import ModelArguments, DataTrainingArguments

logger = logging.getLogger(__name__)

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load dataset
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
        extension = data_args.validation_file.split(".")[-1]
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
        extension = data_args.test_file.split(".")[-1]

    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    if model_args.ptuning_checkpoint is not None:
        # Evaluation
        # Loading extra state dict of prefix encoder
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

这段代码的大致目的是为了训练、评估或预测一个Seq2Seq模型。以下是这段代码的逐行解释:

  1. from datasets import load_dataset:从datasets库中导入load_dataset函数,这个函数用于加载各种预处理后的数据集。

  2. import jieba:导入jieba库,一个中文分词库。

  3. from rouge_chinese import Rouge:从rouge_chinese模块中导入Rouge类,用于计算中文的Rouge分数,这是评估摘要和翻译的一个常用指标。

  4. from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction:从nltk库中导入计算BLEU分数的函数,BLEU分数也是评估翻译质量的一个常用指标。

  5. import torch:导入torch库,这是一个深度学习库,提供了各种深度学习模型和函数。

  6. import transformersfrom transformers import (...):导入transformers库和一些具体的类或函数。Transformers库包含了许多预训练的深度学习模型和相关工具。

  7. from trainer_seq2seq import Seq2SeqTrainer:从trainer_seq2seq模块中导入Seq2SeqTrainer类,这个类用于训练Seq2Seq模型。

  8. from arguments import ModelArguments, DataTrainingArguments:从arguments模块中导入ModelArguments和DataTrainingArguments类,这两个类分别用于解析模型和数据训练的命令行参数。

  9. logger = logging.getLogger(__name__):获取一个logger对象,用于输出日志信息。

  10. def main()::定义主函数,整个程序的入口。

  11. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)):创建一个解析器对象,用于解析命令行参数。

  12. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):else:块:判断命令行参数是否为一个.json文件,如果是,就从.json文件中读取参数;否则,直接从命令行参数解析。

  13. logging.basicConfig(...):设置日志的输出格式和处理器。

  14. 如果training_args.should_log为真,那么设置日志等级为info,否则,根据training_args的内容设置日志等级。

  15. 输出一些关于训练的信息,例如设备信息、是否分布式训练等。

  16. set_seed(training_args.seed):设置随机种子,保证实验的可复现性。

  17. 从命令行参数中获取训练、验证和测试数据文件的路径,并用load_dataset函数加载这些数据集。

  18. 从预训练模型加载配置、模型和tokenizer。

  19. 如果提供了ptuning_checkpoint参数,那么加载这个参数指定的模型;否则,加载预训练模型。

  20. 如果提供了quantization_bit参数,那么对模型进行量化操作。

  21. 如果提供了pre_seq_len参数,那么对模型进行半精度(half)计算,否则,进行单精度(float)计算。

  22. 从命令行参数中获取source_prefix参数,如果没有提供,那么设为空字符串。

  23. 根据命令行参数中的do_train、do_eval和do_predict参数的值决定是进行训练、评估还是预测,并对数据集进行相应的预处理。

  24. 如果没有提供do_train、do_eval和do_predict参数,那么输出信息并退出。

让我们更详细地分析这段代码:

  1. from datasets import load_dataset: 从 Hugging Face 的 datasets 库中导入 load_dataset 函数,用于加载各种预处理后的数据集。

  2. import jieba: 导入jieba,它是一个用于中文分词的Python库。

  3. from rouge_chinese import Rouge: 从 rouge_chinese 模块中导入 Rouge 类,这个类可以用来计算 Rouge 分数,它是一种用来评估机器生成文本(如机器翻译或文本摘要)与人类参考文本之间相似度的指标。

  4. from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction: 从 nltk.translate.bleu_score 模块中导入 sentence_bleuSmoothingFunctionsentence_bleu 是用来计算单个句子的 BLEU 分数的函数,而 SmoothingFunction 是用来处理BLEU分数计算过程中出现的0分情况。

  5. import torch: 导入 PyTorch 库,这是一个常用的深度学习框架。

  6. import transformersfrom transformers import (...): 这两行导入了 transformers 库及其一些子模块。transformers 库提供了许多预训练的神经网络模型,可以用于各种自然语言处理任务。

  7. from trainer_seq2seq import Seq2SeqTrainer: 从 trainer_seq2seq 模块导入 Seq2SeqTrainer 类,这个类是用来训练序列到序列(seq2seq)模型的。

  8. from arguments import ModelArguments, DataTrainingArguments: 这行代码从 arguments 模块导入了两个类,这两个类用于解析和处理命令行参数。

  9. logger = logging.getLogger(__name__): 创建一个记录器(logger),这个记录器可以用来记录脚本的运行情况。

  10. def main():: 定义主函数。

  11. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)): 创建一个 HfArgumentParser 对象,它将解析和处理 ModelArgumentsDataTrainingArgumentsSeq2SeqTrainingArguments 这三个类的实例。

  12. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):: 这行代码检查脚本的命令行参数是否为一个 .json 文件。如果是,那么将会从这个文件中读取参数。

  13. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])): 这行代码读取 .json 文件中的参数,并将其分别赋值给 model_argsdata_argstraining_args

  14. else: model_args, data_args, training_args = parser.parse_args_into_dataclasses(): 如果命令行参数不是一个 .json 文件,那么这行代码将会直接从命令行参数中解析出参数。

  15. logging.basicConfig(...):设置日志的基础配置,包括日志的格式、日期格式以及处理器。

  16. if training_args.should_log: transformers.utils.logging.set_verbosity_info(): 如果 training_args.should_log 为真(即需要记录日志),那么设置日志等级为 info

  17. log_level = training_args.get_process_log_level(): 获取 training_args 中定义的日志等级。

  18. logger.setLevel(log_level): 设置 logger 的日志等级。

  19. transformers.utils.logging.set_verbosity(log_level): 设置 transformers 模块的日志等级。

  20. transformers.utils.logging.enable_default_handler()transformers.utils.logging.enable_explicit_format(): 启用默认的日志处理器和明确的日志格式。

  21. logger.warning(...)logger.info(...): 使用 logger 输出一些警告和信息。

  22. set_seed(training_args.seed): 设置随机种子以保证实验结果的一致性。

  23. 这一大段代码主要是根据 data_args(数据相关参数)来加载训练集、验证集和测试集。

  24. 这一大段代码主要是根据 model_args(模型相关参数)来加载预训练模型和相关配置。

  25. 这一大段代码主要是根据参数来确定模型的量化和精度设置。

  26. prefix = data_args.source_prefix if data_args.source_prefix is not None else "": 设置数据集的前缀,如果没有设置则默认为空字符串。

  27. 这一大段代码根据 training_args(训练相关参数)来确定是进行训练、验证还是测试,并对数据集进行相应的预处理。

  28. if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:return: 如果没有进行训练、验证和预测的命令,则输出相应的信息并结束程序。

  1. def preprocess_function(examples)::定义一个预处理函数,对输入的例子进行预处理。

  2. inputs = [prefix + ex[source_lang] for ex in examples[source_lang]]: 把源语言的文本与前缀组合在一起作为输入。

  3. targets = [ex[target_lang] for ex in examples[target_lang]]:将目标语言的文本作为目标输出。

  4. model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True):使用tokenizer对输入进行处理,包括分词,添加特殊标记,转换为模型需要的输入格式等。

  5. # Setup the tokenizer for targets:设置目标的tokenizer。

  6. with tokenizer.as_target_tokenizer()::设置tokenizer为目标语言的tokenizer。

  7. labels = tokenizer(targets, max_length=data_args.max_target_length, padding="max_length", truncation=True):对目标文本进行同样的处理。

  8. model_inputs["labels"] = labels["input_ids"]:将处理后的目标文本的输入ID作为标签。

  9. return model_inputs:返回处理后的模型输入。

  10. if training_args.do_train::如果进行训练。

  11. if "train" not in raw_datasets::如果训练集不在原始数据集中。

  12. raise ValueError("--do_train requires a train dataset"):抛出错误。

  13. train_dataset = raw_datasets["train"]:获取训练集。

  14. if data_args.max_train_samples is not None::如果设置了训练样本的最大数量。

  15. train_dataset = train_dataset.select(range(data_args.max_train_samples)):从训练集中选择一定数量的样本。

  16. train_dataset = train_dataset.map(:对训练集进行映射处理,即对每个样本应用预处理函数。

  17. preprocess_function,:调用预处理函数。

  18. batched=True,:批量处理。

  19. num_proc=data_args.preprocessing_num_workers,:设置处理的进程数。

  20. remove_columns=column_names,:移除原来的列。

  21. load_from_cache_file=not data_args.overwrite_cache,:如果不覆盖缓存,则从缓存中加载。

  22. desc="running tokenizer on train dataset",:设置描述信息。

  23. )

  24. if training_args.do_eval::如果进行评估。

  25. max_target_length = data_args.max_target_length:设置目标文本的最大长度。

  26. if data_args.val_max_target_length is not None::如果设置了验证目标文本的最大长度。

  27. max_target_length = data_args.val_max_target_length:使用设置的验证目标文本的最大长度。

  28. if "validation" not in raw_datasets::如果验证集不在原始数据集中。

  29. raise ValueError("--do_eval requires a validation dataset"):抛出错误。

  30. eval_dataset = raw_datasets["validation"]:获取验证集。

  31. if data_args.max_eval_samples is not None::如果设置了评估样本的最大数量。

  32. eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)):从验证集中选择一定数量的样本。

  33. eval_dataset = eval_dataset.map(:对验证集进行映射处理,即对每个样本应用预处理函数。

  34. preprocess_function,:调用预处理函数。

  35. batched=True,:批量处理。

  36. num_proc=data_args.preprocessing_num_workers,:设置处理的进程数。

  37. remove_columns=column_names,:移除原来的列。

  38. load_from_cache_file=not data_args.overwrite_cache,:如果不覆盖缓存,则从缓存中加载。

  39. desc="running tokenizer on validation dataset",:设置描述信息。

  40. )

  41. if training_args.do_predict::如果进行预测。

  42. max_target_length = data_args.max_target_length:设置目标文本的最大长度。

  43. if data_args.predict_with_generate and data_args.test_max_target_length is not None::如果设置了预测目标文本的最大长度。

  44. max_target_length = data_args.test_max_target_length:使用设置的预测目标文本的最大长度。

  45. if "test" not in raw_datasets::如果测试集不在原始数据集中。

  46. raise ValueError("--do_predict requires a test dataset"):抛出错误。

  47. predict_dataset = raw_datasets["test"]:获取测试集。

  48. predict_dataset = predict_dataset.map(:对测试集进行映射处理,即对每个样本应用预处理函数。

  49. preprocess_function,:调用预处理函数。

  50. batched=True,:批量处理。

  51. num_proc=data_args.preprocessing_num_workers,:设置处理的进程数。

  52. remove_columns=column_names,:移除原来的列。

  53. load_from_cache_file=not data_args.overwrite_cache,:如果不覆盖缓存,则从缓存中加载。

  54. desc="running tokenizer on prediction dataset",:设置描述信息。

  55. )

这段代码主要是根据所提供的参数进行数据的预处理,包括训练集、验证集和测试集的处理,并根据命令行参数决定是否进行训练、验证和预测。在处理数据集时,都会调用 preprocess_function 函数对每个样本进行预处理,然后使用 map 函数应用这个预处理函数。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131617133