ChatGLM2-6B源码解析./ptuning/main.py (二)

    # Get the column names for input/target.
    prompt_column = data_args.prompt_column
    response_column = data_args.response_column
    history_column = data_args.history_column
    
    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    def preprocess_function_eval(examples):
        inputs, targets = [], []
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                history = examples[history_column][i] if history_column is not None else None
                prompt = tokenizer.build_prompt(query, history)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    def preprocess_function_train(examples):
        max_seq_length = data_args.max_source_length + data_args.max_target_length + 1

        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]

                history = examples[history_column][i] if history_column is not None else None
                prompt = tokenizer.build_prompt(query, history)

                prompt = prefix + prompt
                a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                         max_length=data_args.max_source_length)
                b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                                         max_length=data_args.max_target_length)

                context_length = len(a_ids)
                input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
                labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
                
                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)

        return model_inputs
    
    def print_dataset_example(example):
        print("input_ids", example["input_ids"])
        print("inputs", tokenizer.decode(example["input_ids"]))
        print("label_ids", example["labels"])
        print("labels", tokenizer.decode(example["labels"]))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function_train,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        print_dataset_example(train_dataset[0])

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        print_dataset_example(eval_dataset[0])

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        print_dataset_example(predict_dataset[0])

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False
    )

    # Metric
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
            result = scores[0]
            
            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        save_changed=model_args.pre_seq_len is not None
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        # elif last_checkpoint is not None:
        #     checkpoint = last_checkpoint
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        # trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                labels = tokenizer.batch_decode(
                    predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                labels = [label.strip() for label in labels]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    for p, l in zip(predictions, labels):
                        res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
                        writer.write(f"{res}\n")
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
  1. prompt_column, response_column, history_column: 这些变量被定义为用于读取训练数据的列名。prompt_column和response_column分别是提问和回答的列,history_column是聊天记录的列。

  2. max_target_length: 这个变量是指预测的最大长度。

  3. preprocess_function_eval: 这是一个预处理函数,用于在评估阶段对数据进行处理。它创建了输入和目标列表,然后迭代数据集中的每个示例。对于每个示例,它检查是否有prompt和response,然后使用tokenizer将prompt和history转换为模型可以理解的格式。然后,所有的输入都被添加到一个前缀,并用tokenizer进行编码。最后,对目标进行同样的处理,并将处理后的输入和目标加入到模型输入中。

  4. preprocess_function_train: 这是一个预处理函数,用于在训练阶段对数据进行处理。它的处理方式与eval的预处理函数类似,但有一些不同之处,例如它还添加了一个eos(end of sentence) token到输入和标签的末尾,并确保输入和标签的长度都符合最大序列长度。

  5. print_dataset_example: 这个函数用于打印数据集中的一个示例。

  6. training_args.do_train: 这是一个条件语句,如果训练参数中的do_train设定为True,那么它会执行训练数据的预处理并打印一个训练数据的示例。

  7. training_args.do_eval: 这也是一个条件语句,如果训练参数中的do_eval设定为True,那么它会执行验证数据的预处理并打印一个验证数据的示例。

  8. training_args.do_predict: 同样是一个条件语句,如果训练参数中的do_predict设定为True,那么它会执行测试数据的预处理并打印一个测试数据的示例。

  9. label_pad_token_id, data_collator: 这些变量被定义为处理序列到序列任务的工具。label_pad_token_id是用于填充标签的token的ID,data_collator用于处理批量数据。

  10. compute_metrics: 这个函数用于计算评估指标。它首先解码预测和标签,然后计算ROUGE和BLEU评分。

  11. trainer: 这个变量是一个Seq2SeqTrainer对象,它用于训练模型。

  12. if training_args.do_train: 如果训练参数中的do_train设定为True,那么它会执行训练,并保存训练的指标和状态。

  13. if training_args.do_eval: 如果训练参数中的do_eval设定为True,那么它会执行评估,并保存评估的指标。

  14. if training_args.do_predict: 如果训练参数中的do_predict设定为True,那么它会执行预测,并保存预测的指标。最后,它还将预测的结果写入文件。

  15. def _mp_fn(index): 这个函数是用于TPUs的函数,如果在TPU上运行,则会调用这个函数。

  16. if name == "main": 这个判断语句用来判断是否为脚本的运行入口,如果是,则执行main()函数。

整个脚本是用于训练和评估一个聊天机器人模型的代码。脚本先进行数据的预处理,然后定义一个模型训练器并使用它来训练模型。在训练后,脚本使用模型进行评估和预测,并将预测结果保存到文件中。在整个过程中,脚本使用了许多与序列到序列任务相关的工具,如数据整理器和评估指标计算函数。

当然,让我们逐行来解析这些代码。

  1. prompt_column = data_args.prompt_column:从数据参数中获取提示列名称,也就是用于提问的列。
  2. response_column = data_args.response_column:从数据参数中获取回答列的名称,也就是作为回答或目标的列。
  3. history_column = data_args.history_column:从数据参数中获取历史对话列的名称,如果存在的话,这些历史对话将被用作提问的上下文。

以下是预处理函数,它们用于将输入和目标列进行格式化和分词。格式化的结果将被用于模型的训练和验证。

  1. preprocess_function_evalpreprocess_function_train:这两个函数是为评估和训练准备数据的。它们从示例数据中提取问题和回答,并根据需要将其进行格式化和分词。然后它们会将输入和目标添加到model_inputs列表中,然后返回这个列表。

在接下来的代码中,我们根据是否要进行训练、评估或预测,以及提供的数据集中是否包含所需的部分(训练、验证或测试),来分别处理数据集。

  1. if training_args.do_train::如果设置了训练标志,那么就需要检查是否提供了训练数据集,然后根据需要进行预处理。然后打印出第一个训练样例。

  2. if training_args.do_eval::类似地,如果设置了评估标志,那么就需要检查是否提供了评估数据集,并进行预处理。然后打印出第一个评估样例。

  3. if training_args.do_predict::对于预测,我们需要检查是否提供了测试数据集,并进行预处理。然后打印出第一个测试样例。

  4. data_collator = DataCollatorForSeq2Seq(...):创建一个数据整理器,用于将预处理后的输入数据组装成可以直接喂入模型的批次。

接下来是评估指标的计算函数。这个函数将模型的预测结果与实际标签进行比较,然后计算并返回指标分数。

  1. compute_metrics:这个函数接收预测和标签,首先进行解码,然后计算rouge和bleu分数。

接着,我们覆盖一些解码参数,然后初始化训练器,并进行训练、评估和预测。

  1. trainer = Seq2SeqTrainer(...):初始化一个训练器,它将用于训练、评估和预测。

  2. if training_args.do_train::如果设置了训练标志,就进行训练,并在训练结束后保存模型和指标。

  3. if training_args.do_eval::如果设置了评估标志,就进行评估,并记录并保存评估指标。

  4. if training_args.do_predict::如果设置了预测标志,就进行预测,并记录并保存预测指标。如果预测是使用生成方法完成的,就将预测和标签保存到文件中。

最后,如果此脚本是作为主脚本运行的,就调用main函数。

  1. if __name__ == "__main__"::如果此脚本是作为主脚本运行的,就调用main函数。这是Python的一种常见模式,用于检查脚本是直接运行还是作为模块导入。只有在直接运行脚本时,__name__的值才会是"__main__",因此只有在这种情况下,才会调用main函数。

for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                history = examples[history_column][i] if history_column is not None else None
                prompt = tokenizer.build_prompt(query, history)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131621397