class Seq2SeqTrainerForChatGLM(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def save_predictions(
        self,
        predict_results: PredictionOutput,
        tokenizer: PreTrainedTokenizer
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.

        Only world process zero writes the file; every other process returns
        without doing anything.

        Args:
            predict_results: prediction output whose `predictions` and
                `label_ids` arrays are decoded and saved.
            tokenizer: tokenizer that supplies the padding id substituted for
                IGNORE_INDEX and that decodes the token ids into text.

        The result is written to `generated_predictions.jsonl` inside
        `self.args.output_dir`, one JSON object per line with the keys
        "label" and "predict".
        """
        if not self.is_world_process_zero():
            return

        # Take the padding id from the same tokenizer that decodes below, so
        # the substituted id always belongs to the decoding vocabulary and the
        # method does not depend on `self.tokenizer` being set on the trainer.
        pad_token_id = tokenizer.pad_token_id

        # Swap every IGNORE_INDEX entry for the padding id before decoding.
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            pad_token_id
        )
        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX,
            predict_results.label_ids,
            pad_token_id
        )

        # Remove prompts: drop the first `labels.shape[1]` columns of each row.
        # NOTE(review): assumes the prompt occupies exactly that many leading
        # positions of every prediction -- confirm against the prediction step.
        preds = preds[:, labels.shape[1]:]

        decoded_preds = [
            tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds
        ]
        decoded_labels = [
            tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels
        ]

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")

        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        records = [
            json.dumps({"label": label, "predict": pred}, ensure_ascii=False)
            for pred, label in zip(decoded_preds, decoded_labels)
        ]
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            writer.write("\n".join(records))
Here's a Python snippet that defines a class called `Seq2SeqTrainerForChatGLM`, which inherits from `PeftTrainer`.
This class specifically handles the training of sequence-to-sequence models, as well as the computation of generative metrics such as BLEU and ROUGE.
Next, we will explain this code line by line:
-
class Seq2SeqTrainerForChatGLM(PeftTrainer):
- This line defines a new class, `Seq2SeqTrainerForChatGLM`, that inherits from the `PeftTrainer` class. -
def save_predictions(self, predict_results: PredictionOutput, tokenizer: PreTrainedTokenizer) -> None:
- This is the `save_predictions` method, which accepts two parameters: `predict_results` (the prediction result, of type `PredictionOutput`) and `tokenizer` (the pretrained tokenizer, of type `PreTrainedTokenizer`). -
if not self.is_world_process_zero(): return
- This line checks whether the current process is the main process (process number 0). If it is not, the method returns immediately. -
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
- This line creates a new array of predictions: values that are not equal to `IGNORE_INDEX` are preserved, and values equal to `IGNORE_INDEX` are replaced with the ID of the padding token. -
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
- This line creates a new label array: label values that are not equal to `IGNORE_INDEX` are kept, and label values equal to `IGNORE_INDEX` are replaced with the ID of the padding token. -
preds = preds[:, labels.shape[1]:]
- This line removes the prompts from the predictions by dropping the first `labels.shape[1]` columns; the prompts are at the beginning of the predictions array. -
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
- This line uses the tokenizer to decode each prediction into text, skipping special tokens and stripping leading and trailing whitespace. -
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
- This line decodes the labels using the tokenizer, skipping special tokens and stripping leading and trailing whitespace. -
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
- This line defines the output file path of the prediction results. -
logger.info(f"Saving prediction results to {output_prediction_file}")
- This line sends a message to the log indicating which file the predictions will be saved to. -
with open(output_prediction_file, "w", encoding="utf-8") as writer:
- This line opens the prediction output file in write mode with UTF-8 encoding, ready to write the prediction results. -
res: List[str] = []
- This line initializes an empty list ready to collect predictions for each row. -
for pred, label in zip(preds, labels):
- This line starts a loop over all predicted values and labels. -
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
- In the loop, this line serializes each prediction/label pair to a JSON string (keeping non-ASCII characters as-is) and appends it to the result list. -
writer.write("\n".join(res))
- This line concatenates the list of results into a string, with each result separated by a newline, and writes it to the file.