ChatGLM2-6B source code analysis: ./ptuning/main.py (2)

    # Get the column names for input/target.
    prompt_column = data_args.prompt_column
    response_column = data_args.response_column
    history_column = data_args.history_column
    
    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    def preprocess_function_eval(examples):
        inputs, targets = [], []
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                history = examples[history_column][i] if history_column is not None else None
                prompt = tokenizer.build_prompt(query, history)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    def preprocess_function_train(examples):
        max_seq_length = data_args.max_source_length + data_args.max_target_length + 1

        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]

                history = examples[history_column][i] if history_column is not None else None
                prompt = tokenizer.build_prompt(query, history)

                prompt = prefix + prompt
                a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                         max_length=data_args.max_source_length)
                b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                                         max_length=data_args.max_target_length)

                context_length = len(a_ids)
                input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
                labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
                
                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)

        return model_inputs
    
    def print_dataset_example(example):
        print("input_ids", example["input_ids"])
        print("inputs", tokenizer.decode(example["input_ids"]))
        print("label_ids", example["labels"])
        print("labels", tokenizer.decode(example["labels"]))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function_train,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        print_dataset_example(train_dataset[0])

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        print_dataset_example(eval_dataset[0])

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        print_dataset_example(predict_dataset[0])

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False
    )

    # Metric
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
            result = scores[0]
            
            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        save_changed=model_args.pre_seq_len is not None
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        # elif last_checkpoint is not None:
        #     checkpoint = last_checkpoint
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        # trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                labels = tokenizer.batch_decode(
                    predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                labels = [label.strip() for label in labels]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    for p, l in zip(predictions, labels):
                        res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
                        writer.write(f"{res}\n")
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
  1. prompt_column, response_column, history_column : These variables hold the column names used to read the training data. prompt_column and response_column are the columns for the question and the answer respectively, and history_column is the column holding the conversation history.

  2. max_target_length : This variable holds the maximum length of the target (label) sequence. It starts as data_args.max_target_length for training and is later switched to data_args.val_max_target_length for evaluation and prediction.

  3. preprocess_function_eval : This preprocessing function prepares the data for the evaluation phase. It creates the inputs and targets lists and iterates over every example in the batch; for each example that has both a prompt and a response, it uses tokenizer.build_prompt to combine the query with its history, appends the resulting prompt to inputs, and appends the response to targets. The prefix is then prepended to every input, the inputs are tokenized (truncated to max_source_length and padded), and the targets are tokenized separately (truncated to max_target_length). If ignore_pad_token_for_loss is set, pad tokens in the label ids are replaced with -100, and the label ids are attached to model_inputs as its labels.

  4. preprocess_function_train : This preprocessing function prepares the data for the training phase. It works much like the eval version, but instead of tokenizing inputs and targets independently it concatenates the encoded prompt (a_ids), the encoded answer (b_ids), and an eos (end-of-sequence) token into a single input_ids sequence. The labels mask the prompt part with pad tokens and keep the answer plus the eos token, and both sequences are padded to max_seq_length (max_source_length + max_target_length + 1); when ignore_pad_token_for_loss is set, pad positions in the labels become -100 (a worked sketch follows this list).

  5. print_dataset_example : This function is used to print an example from the dataset.

  6. training_args.do_train : This is a conditional block: if do_train is set in the training arguments, the train split is (optionally truncated to max_train_samples and) mapped through preprocess_function_train, and the first processed training example is printed.

  7. training_args.do_eval : Likewise, if do_eval is set, the validation split is preprocessed with preprocess_function_eval and the first processed validation example is printed.

  8. training_args.do_predict : If do_predict is set, the test split is preprocessed with preprocess_function_eval and the first processed test example is printed.

  9. label_pad_token_id, data_collator : These set up the batching tools for the sequence-to-sequence task. label_pad_token_id is the id used to pad label sequences (-100 when ignore_pad_token_for_loss is set, so padded positions are ignored by the loss; otherwise the tokenizer's pad token id), and data_collator is a DataCollatorForSeq2Seq that assembles the preprocessed examples into batches.

  10. compute_metrics : This function computes the evaluation metrics. It first decodes the predictions and labels, then computes ROUGE-1/2/L scores on jieba-segmented text and a BLEU-4 score, and returns their averages.

  11. trainer : This variable is a Seq2SeqTrainer object, which is used to train the model.

  12. if training_args.do_train : If do_train is set in the training arguments, training is run and the training metrics and trainer state are saved.

  13. if training_args.do_eval : If do_eval is set in the training arguments, evaluation is run and the evaluation metrics are saved.

  14. if training_args.do_predict : If do_predict is set in the training arguments, prediction is run and the prediction metrics are saved. Finally, the generated predictions and labels are also written to a file.

  15. def _mp_fn(index) : This helper is the per-process entry point used by xla_spawn when running on TPUs; it simply calls main().

  16. if __name__ == "__main__" : This check determines whether the script is the program's entry point (run directly rather than imported); if so, the main() function is executed.
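
To make item 4 concrete, here is a minimal worked sketch of how preprocess_function_train builds one training example. The token ids, pad_token_id, eos_token_id and length limits are toy values chosen for illustration, not the real ChatGLM2 ids.

pad_token_id, eos_token_id = 0, 2                      # assumed ids, illustration only
max_source_length, max_target_length = 4, 3
max_seq_length = max_source_length + max_target_length + 1   # 8

a_ids = [11, 12, 13]                                   # encoded prefix + prompt
b_ids = [21, 22]                                       # encoded answer

context_length = len(a_ids)
input_ids = a_ids + b_ids + [eos_token_id]             # [11, 12, 13, 21, 22, 2]
labels = [pad_token_id] * context_length + b_ids + [eos_token_id]   # [0, 0, 0, 21, 22, 2]

pad_len = max_seq_length - len(input_ids)              # 2
input_ids = input_ids + [pad_token_id] * pad_len       # [11, 12, 13, 21, 22, 2, 0, 0]
labels = labels + [pad_token_id] * pad_len             # [0, 0, 0, 21, 22, 2, 0, 0]

# With ignore_pad_token_for_loss, every pad position in the labels becomes -100:
labels = [(l if l != pad_token_id else -100) for l in labels]
# -> [-100, -100, -100, 21, 22, 2, -100, -100]

So the loss is computed only on the answer tokens and the closing eos; the prompt and the padding are masked out.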

The entire script is the code used to train and evaluate a dialogue model. It first preprocesses the data, then defines a trainer and uses it to train the model. After training, the script evaluates and predicts with the model and saves the predictions to a file. Throughout the process it relies on standard sequence-to-sequence tooling, such as the data collator and the metric-computation function.

Now, let's walk through this code step by step.

  1. prompt_column = data_args.prompt_column: Read the prompt column name from the data arguments, i.e. the column that holds the question.
  2. response_column = data_args.response_column: Read the response column name from the data arguments, i.e. the column that holds the answer (the target).
  3. history_column = data_args.history_column: Read the history column name from the data arguments; if present, the past dialogue turns in this column are used as context for the question (a toy record is shown right after this list).
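
For illustration, a single-turn record pairs an input column with a target column, and a multi-turn record adds a history column holding past (question, answer) pairs. The field names below are hypothetical examples, not something the script mandates; pass whatever your data uses via --prompt_column, --response_column and --history_column.

record = {"content": "类型#裤*材质#牛仔布", "summary": "一条舒适百搭的牛仔裤。"}
# read with:  --prompt_column content  --response_column summary

record_with_history = {
    "prompt": "它耐穿吗？",
    "response": "非常耐穿。",
    "history": [["这是什么？", "一条牛仔裤。"]],   # list of past (question, answer) rounds
}
# read with:  --prompt_column prompt  --response_column response  --history_column history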

The following are preprocessing functions, which are used to format and tokenize the input and target columns. The formatted results will be used for model training and validation.

  1. preprocess_function_eval and preprocess_function_train: These two functions prepare the data for evaluation and training respectively. They extract questions and answers from the examples, build the prompt (folding in any history, as sketched below), tokenize the text, and return the resulting inputs and labels in the model_inputs dictionary.
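
A simplified, hypothetical stand-in for tokenizer.build_prompt (the authoritative implementation ships with the ChatGLM2-6B tokenizer) shows how past rounds are folded into a single prompt string:

def build_prompt_sketch(query, history=None):
    # Render each past (question, answer) round before the current question; the string
    # ends with an open "答：" so the model is expected to continue with the answer.
    history = history or []
    prompt = ""
    for i, (old_query, old_response) in enumerate(history):
        prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, old_response)
    prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
    return prompt

print(build_prompt_sketch("它耐穿吗？", history=[("这是什么？", "一条牛仔裤。")]))

The eval function then tokenizes this prompt and the response separately, while the train function concatenates their token ids as in the earlier padding sketch.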

In the code that follows, we process the datasets separately, depending on whether we want to train, evaluate, or predict, and whether the desired part (train, validation, or test) is included in the provided dataset.

  1. if training_args.do_train:: If the training flag is set, check that a train split is provided, preprocess it with preprocess_function_train (see the toy map() sketch after this list), and print out the first processed training example.

  2. if training_args.do_eval:: Similarly, if the evaluation flag is set, check that a validation split is provided, preprocess it with preprocess_function_eval, and print out the first processed validation example.

  3. if training_args.do_predict:: For prediction, check that a test split is provided, preprocess it the same way, and print out the first processed test example.

  4. data_collator = DataCollatorForSeq2Seq(...): Creates a data collator that assembles the preprocessed examples into batches that can be fed directly to the model.
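
All three branches use the same datasets.map(batched=True, remove_columns=...) pattern. Here is a self-contained toy sketch with an in-memory dataset and a stand-in preprocessing function (not the real tokenizer), just to show the mechanics:

from datasets import Dataset

raw = Dataset.from_dict({
    "content": ["类型#裤*材质#牛仔布", "类型#裙*风格#简约"],
    "summary": ["一条舒适的牛仔裤。", "一条简约风格的裙子。"],
})

def toy_preprocess(examples):
    # Stand-in for preprocess_function_train/eval: with batched=True, map() passes
    # whole columns of values, and we return new columns of the same length.
    return {
        "input_len": [len(c) for c in examples["content"]],
        "target_len": [len(s) for s in examples["summary"]],
    }

processed = raw.map(toy_preprocess, batched=True, remove_columns=raw.column_names)
print(processed[0])   # first processed row, e.g. {'input_len': 11, 'target_len': 9}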

Next comes the metric-computation function. It compares the model's predictions with the reference labels, then computes and returns the metric scores.

  1. compute_metrics: This function receives predictions and labels, decodes both, and then computes ROUGE and BLEU-4 scores averaged over the evaluation set, as in the sketch below.
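
A minimal per-pair sketch of that computation. It assumes the jieba, rouge-chinese and nltk packages; the script imports jieba, Rouge and sentence_bleu, but treat the exact package names here as assumptions if your copy of the repository differs.

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

pred, label = "一条百搭的牛仔裤", "一条百搭又显瘦的牛仔裤"

# ROUGE is computed on space-joined jieba tokens, BLEU-4 on character lists.
hypothesis = " ".join(jieba.cut(pred))
reference = " ".join(jieba.cut(label))
scores = Rouge().get_scores(hypothesis, reference)[0]

rouge_l = round(scores["rouge-l"]["f"] * 100, 4)
bleu_4 = round(sentence_bleu([list(label)], list(pred),
                             smoothing_function=SmoothingFunction().method3) * 100, 4)
print(rouge_l, bleu_4)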

Next, we override some decoding parameters, then initialize the trainer, and perform training, evaluation, and prediction.

  1. trainer = Seq2SeqTrainer(...): Initialize a trainer that will be used for training, evaluation and prediction.

  2. if training_args.do_train:: If the training flag is set, run training, then log and save the training metrics and the trainer state.

  3. if training_args.do_eval:: If the evaluation flag is set, run evaluation, then log and save the evaluation metrics.

  4. if training_args.do_predict:: If the prediction flag is set, run prediction, then log and save the prediction metrics. If predict_with_generate is enabled, the decoded predictions and labels are also written to a file (see the sketch below).
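
That final dump writes one JSON object per line to generated_predictions.txt, along these lines (toy strings and a hypothetical output directory; the script uses training_args.output_dir):

import json
import os

predictions = ["一条百搭的牛仔裤"]
labels = ["一条百搭又显瘦的牛仔裤"]

output_dir = "output"   # hypothetical; main.py takes this from training_args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_prediction_file = os.path.join(output_dir, "generated_predictions.txt")

# One JSON object per line; ensure_ascii=False keeps Chinese characters readable.
with open(output_prediction_file, "w", encoding="utf-8") as writer:
    for p, l in zip(predictions, labels):
        writer.write(json.dumps({"labels": l, "predict": p}, ensure_ascii=False) + "\n")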

Finally, if the script is run as the main program, the main function is called.

  1. if __name__ == "__main__":: If this script is run directly as the main program, the main function is called. This is a common Python pattern for distinguishing between running a script directly and importing it as a module: __name__ equals "__main__" only when the script is run directly, so only in that case is main() called (the toy example below illustrates this).
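
A tiny illustration of that guard, saved as a hypothetical demo.py:

def main():
    print("running main")

def _mp_fn(index):
    # Per-process entry point used when the script is spawned per TPU core; it just calls main().
    main()

if __name__ == "__main__":   # True for "python demo.py", False for "import demo"
    main()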


Origin blog.csdn.net/sinat_37574187/article/details/131621397