from datasets import load_dataset
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer
from arguments import ModelArguments, DataTrainingArguments
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Load dataset
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
if model_args.ptuning_checkpoint is not None:
# Evaluation
# Loading extra state dict of prefix encoder
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
if model_args.quantization_bit is not None:
print(f"Quantized to {model_args.quantization_bit} bit")
model = model.quantize(model_args.quantization_bit)
if model_args.pre_seq_len is not None:
# P-tuning v2
model = model.half()
model.transformer.prefix_encoder.float()
else:
# Finetune
model = model.float()
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
column_names = raw_datasets["train"].column_names
elif training_args.do_eval:
column_names = raw_datasets["validation"].column_names
elif training_args.do_predict:
column_names = raw_datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
这段代码的大致目的是为了训练、评估或预测一个Seq2Seq模型。以下是这段代码的逐行解释:
-
from datasets import load_dataset
:从datasets库中导入load_dataset函数,这个函数用于加载各种预处理后的数据集。 -
import jieba
:导入jieba库,一个中文分词库。 -
from rouge_chinese import Rouge
:从rouge_chinese模块中导入Rouge类,用于计算中文的Rouge分数,这是评估摘要和翻译的一个常用指标。 -
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
:从nltk库中导入计算BLEU分数的函数,BLEU分数也是评估翻译质量的一个常用指标。 -
import torch
:导入torch库,这是一个深度学习库,提供了各种深度学习模型和函数。 -
import transformers
和from transformers import (...)
:导入transformers库和一些具体的类或函数。Transformers库包含了许多预训练的深度学习模型和相关工具。 -
from trainer_seq2seq import Seq2SeqTrainer
:从trainer_seq2seq模块中导入Seq2SeqTrainer类,这个类用于训练Seq2Seq模型。 -
from arguments import ModelArguments, DataTrainingArguments
:从arguments模块中导入ModelArguments和DataTrainingArguments类,这两个类分别用于解析模型和数据训练的命令行参数。 -
logger = logging.getLogger(__name__)
:获取一个logger对象,用于输出日志信息。 -
def main():
:定义主函数,整个程序的入口。 -
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
:创建一个解析器对象,用于解析命令行参数。 -
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
和else:
块:判断命令行参数是否为一个.json文件,如果是,就从.json文件中读取参数;否则,直接从命令行参数解析。 -
logging.basicConfig(...)
:设置日志的输出格式和处理器。 -
如果training_args.should_log为真,那么设置日志等级为info,否则,根据training_args的内容设置日志等级。
-
输出一些关于训练的信息,例如设备信息、是否分布式训练等。
-
set_seed(training_args.seed)
:设置随机种子,保证实验的可复现性。 -
从命令行参数中获取训练、验证和测试数据文件的路径,并用load_dataset函数加载这些数据集。
-
从预训练模型加载配置、模型和tokenizer。
-
如果提供了ptuning_checkpoint参数,那么加载这个参数指定的模型;否则,加载预训练模型。
-
如果提供了quantization_bit参数,那么对模型进行量化操作。
-
如果提供了pre_seq_len参数,那么对模型进行半精度(half)计算,否则,进行单精度(float)计算。
-
从命令行参数中获取source_prefix参数,如果没有提供,那么设为空字符串。
-
根据命令行参数中的do_train、do_eval和do_predict参数的值决定是进行训练、评估还是预测,并对数据集进行相应的预处理。
-
如果没有提供do_train、do_eval和do_predict参数,那么输出信息并退出。
让我们更详细地分析这段代码:
-
from datasets import load_dataset
: 从 Hugging Face 的datasets
库中导入load_dataset
函数,用于加载各种预处理后的数据集。 -
import jieba
: 导入jieba,它是一个用于中文分词的Python库。 -
from rouge_chinese import Rouge
: 从rouge_chinese
模块中导入Rouge
类,这个类可以用来计算 Rouge 分数,它是一种用来评估机器生成文本(如机器翻译或文本摘要)与人类参考文本之间相似度的指标。 -
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
: 从nltk.translate.bleu_score
模块中导入sentence_bleu
和SmoothingFunction
。sentence_bleu
是用来计算单个句子的 BLEU 分数的函数,而SmoothingFunction
是用来处理BLEU分数计算过程中出现的0分情况。 -
import torch
: 导入 PyTorch 库,这是一个常用的深度学习框架。 -
import transformers
和from transformers import (...)
: 这两行导入了transformers
库及其一些子模块。transformers
库提供了许多预训练的神经网络模型,可以用于各种自然语言处理任务。 -
from trainer_seq2seq import Seq2SeqTrainer
: 从trainer_seq2seq
模块导入Seq2SeqTrainer
类,这个类是用来训练序列到序列(seq2seq)模型的。 -
from arguments import ModelArguments, DataTrainingArguments
: 这行代码从arguments
模块导入了两个类,这两个类用于解析和处理命令行参数。 -
logger = logging.getLogger(__name__)
: 创建一个记录器(logger),这个记录器可以用来记录脚本的运行情况。 -
def main():
: 定义主函数。 -
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
: 创建一个HfArgumentParser
对象,它将解析和处理ModelArguments
、DataTrainingArguments
和Seq2SeqTrainingArguments
这三个类的实例。 -
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
: 这行代码检查脚本的命令行参数是否为一个.json
文件。如果是,那么将会从这个文件中读取参数。 -
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
: 这行代码读取.json
文件中的参数,并将其分别赋值给model_args
、data_args
和training_args
。 -
else: model_args, data_args, training_args = parser.parse_args_into_dataclasses()
: 如果命令行参数不是一个.json
文件,那么这行代码将会直接从命令行参数中解析出参数。 -
logging.basicConfig(...)
:设置日志的基础配置,包括日志的格式、日期格式以及处理器。 -
if training_args.should_log: transformers.utils.logging.set_verbosity_info()
: 如果training_args.should_log
为真(即需要记录日志),那么设置日志等级为 info。 -
log_level = training_args.get_process_log_level()
: 获取training_args
中定义的日志等级。 -
logger.setLevel(log_level)
: 设置logger
的日志等级。 -
transformers.utils.logging.set_verbosity(log_level)
: 设置transformers
模块的日志等级。 -
transformers.utils.logging.enable_default_handler()
和transformers.utils.logging.enable_explicit_format()
: 启用默认的日志处理器和明确的日志格式。 -
logger.warning(...)
和logger.info(...)
: 使用logger
输出一些警告和信息。 -
set_seed(training_args.seed)
: 设置随机种子以保证实验结果的一致性。 -
这一大段代码主要是根据
data_args
(数据相关参数)来加载训练集、验证集和测试集。 -
这一大段代码主要是根据
model_args
(模型相关参数)来加载预训练模型和相关配置。 -
这一大段代码主要是根据参数来确定模型的量化和精度设置。
-
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
: 设置数据集的前缀,如果没有设置则默认为空字符串。 -
这一大段代码根据
training_args
(训练相关参数)来确定是进行训练、验证还是测试,并对数据集进行相应的预处理。 -
if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
和return
: 如果没有进行训练、验证和预测的命令,则输出相应的信息并结束程序。
-
def preprocess_function(examples):
:定义一个预处理函数,对输入的例子进行预处理。 -
inputs = [prefix + ex[source_lang] for ex in examples[source_lang]]
: 把源语言的文本与前缀组合在一起作为输入。 -
targets = [ex[target_lang] for ex in examples[target_lang]]
:将目标语言的文本作为目标输出。 -
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True)
:使用tokenizer对输入进行处理,包括分词,添加特殊标记,转换为模型需要的输入格式等。 -
# Setup the tokenizer for targets
:设置目标的tokenizer。 -
with tokenizer.as_target_tokenizer():
:设置tokenizer为目标语言的tokenizer。 -
labels = tokenizer(targets, max_length=data_args.max_target_length, padding="max_length", truncation=True)
:对目标文本进行同样的处理。 -
model_inputs["labels"] = labels["input_ids"]
:将处理后的目标文本的输入ID作为标签。 -
return model_inputs
:返回处理后的模型输入。 -
if training_args.do_train:
:如果进行训练。 -
if "train" not in raw_datasets:
:如果训练集不在原始数据集中。 -
raise ValueError("--do_train requires a train dataset")
:抛出错误。 -
train_dataset = raw_datasets["train"]
:获取训练集。 -
if data_args.max_train_samples is not None:
:如果设置了训练样本的最大数量。 -
train_dataset = train_dataset.select(range(data_args.max_train_samples))
:从训练集中选择一定数量的样本。 -
train_dataset = train_dataset.map(
:对训练集进行映射处理,即对每个样本应用预处理函数。 -
preprocess_function,
:调用预处理函数。 -
batched=True,
:批量处理。 -
num_proc=data_args.preprocessing_num_workers,
:设置处理的进程数。 -
remove_columns=column_names,
:移除原来的列。 -
load_from_cache_file=not data_args.overwrite_cache,
:如果不覆盖缓存,则从缓存中加载。 -
desc="running tokenizer on train dataset",
:设置描述信息。 -
)
-
if training_args.do_eval:
:如果进行评估。 -
max_target_length = data_args.max_target_length
:设置目标文本的最大长度。 -
if data_args.val_max_target_length is not None:
:如果设置了验证目标文本的最大长度。 -
max_target_length = data_args.val_max_target_length
:使用设置的验证目标文本的最大长度。 -
if "validation" not in raw_datasets:
:如果验证集不在原始数据集中。 -
raise ValueError("--do_eval requires a validation dataset")
:抛出错误。 -
eval_dataset = raw_datasets["validation"]
:获取验证集。 -
if data_args.max_eval_samples is not None:
:如果设置了评估样本的最大数量。 -
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
:从验证集中选择一定数量的样本。 -
eval_dataset = eval_dataset.map(
:对验证集进行映射处理,即对每个样本应用预处理函数。 -
preprocess_function,
:调用预处理函数。 -
batched=True,
:批量处理。 -
num_proc=data_args.preprocessing_num_workers,
:设置处理的进程数。 -
remove_columns=column_names,
:移除原来的列。 -
load_from_cache_file=not data_args.overwrite_cache,
:如果不覆盖缓存,则从缓存中加载。 -
desc="running tokenizer on validation dataset",
:设置描述信息。 -
)
-
if training_args.do_predict:
:如果进行预测。 -
max_target_length = data_args.max_target_length
:设置目标文本的最大长度。 -
if data_args.predict_with_generate and data_args.test_max_target_length is not None:
:如果设置了预测目标文本的最大长度。 -
max_target_length = data_args.test_max_target_length
:使用设置的预测目标文本的最大长度。 -
if "test" not in raw_datasets:
:如果测试集不在原始数据集中。 -
raise ValueError("--do_predict requires a test dataset")
:抛出错误。 -
predict_dataset = raw_datasets["test"]
:获取测试集。 -
predict_dataset = predict_dataset.map(
:对测试集进行映射处理,即对每个样本应用预处理函数。 -
preprocess_function,
:调用预处理函数。 -
batched=True,
:批量处理。 -
num_proc=data_args.preprocessing_num_workers,
:设置处理的进程数。 -
remove_columns=column_names,
:移除原来的列。 -
load_from_cache_file=not data_args.overwrite_cache,
:如果不覆盖缓存,则从缓存中加载。 -
desc="running tokenizer on prediction dataset",
:设置描述信息。 -
)
这段代码主要是根据所提供的参数进行数据的预处理,包括训练集、验证集和测试集的处理,并根据命令行参数决定是否进行训练、验证和预测。在处理数据集时,都会调用 preprocess_function
函数对每个样本进行预处理,然后使用 map
函数应用这个预处理函数。