model (required): 待训练的模型,必须是 PyTorch 模型。
args (required): TrainingArguments 对象,包含训练和评估过程的参数,例如训练周期数、学习率、批量大小等。
train_dataset (optional): 训练数据集,可以是 Dataset 对象或一个 list。
eval_dataset (optional): 验证数据集,可以是 Dataset 对象或一个 list。
data_collator (optional): 用于将训练数据转换为批量的函数,如果未指定,则默认使用默认的数据收集器(DataCollatorWithPadding)。
tokenizer (optional): 用于编码文本的 tokenizer 对象,如果未指定,则默认使用 BERT tokenizer。
model_init (optional): 用于初始化模型的函数,如果未指定,则默认使用默认的模型初始化函数。
compute_metrics (optional): 用于计算模型性能的函数,例如准确率、F1 值等。
callbacks (optional): 用于在训练期间执行回调操作的列表,例如 EarlyStoppingCallback、ModelCheckpointCallback 等。
optimizer (optional): 用于优化模型参数的优化器,如果未指定,则默认使用 AdamW。
scheduler (optional): 用于调整学习率的学习率调度器,如果未指定,则默认使用 LinearWarmupCosineAnnealingLR。
device (optional): 指定训练所用的设备,可以是字符串('cpu' 或 'cuda')或 torch.device 对象。
gradient_accumulation_steps (optional): 梯度积累的步数,默认为 1,表示不进行梯度积累。
fp16 (optional): 是否使用混合精度训练,默认为 False。
fp16_opt_level (optional): 混合精度训练的优化级别,默认为 'O1'。
dataloader_num_workers (optional): DataLoader 使用的 worker 数量,默认为 0,表示使用主进程加载数据。
past_index (optional): 指定模型是否使用过去状态,例如 GPT-2 模型会使用过去状态,BERT 模型不会使用。
label_smoother (optional): 用于平滑标签的 LabelSmoothingCrossEntropy 对象。
train_sampler (optional): 用于对训练数据进行采样的 Sampler 对象。
eval_sampler (optional): 用于对验证数据进行采样的 Sampler 对象。
prediction_loss_only (optional): 是否仅计算预测损失,而不计算正则化损失。默认为 False。
huggingface ,Trainer() 函数是 Transformers 库中用于训练和评估模型的主要接口,Trainer()函数的参数如下:
猜你喜欢
转载自blog.csdn.net/weixin_44002458/article/details/130138252
今日推荐
周排行