import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json
import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments
@dataclass
class ModelArguments:
    """Command-line arguments selecting the base model to fine-tune."""

    # Model identifier or local path; the default points at the
    # Baichuan2-7B base checkpoint name.
    model_name_or_path: Optional[str] = field(
        default="baichuan-inc/Baichuan2-7B-Base",
        metadata={"help": "Model identifier or path of the base model."},
    )
@dataclass
class DataArguments:
    """Command-line arguments describing the training data."""

    # The default is None, so the annotation must admit None as well.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extension of ``transformers.TrainingArguments`` with extra fields.

    NOTE: this class intentionally reuses the name ``TrainingArguments`` and
    therefore shadows the name imported from ``transformers.training_args``
    at the top of the file.
    """

    # Optional cache directory; None by default.
    cache_dir: Optional[str] = field(default=None)
    # Optimizer name; defaults to "adamw_torch".
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Flag for LoRA fine-tuning; not consumed anywhere in the visible code.
    use_lora: bool = field(default=False)
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on multi-turn conversations.

    The JSON file at ``data_path`` is expected to hold a sequence of
    examples, each with a ``"conversations"`` list of messages carrying a
    ``"from"`` and a ``"value"`` key. Messages whose ``"from"`` is
    ``"human"`` are user turns; any other sender is handled as the assistant.

    Args:
        data_path: Path of the JSON training file.
        tokenizer: Object providing ``encode``, ``decode``, ``eos_token_id``
            and ``pad_token_id``.
        model_max_length: Every example is truncated and right-padded to
            exactly this many tokens.
        user_tokens: Token ids placed before each user turn
            (``[195]`` when omitted).
        assistant_tokens: Token ids placed before each assistant turn
            (``[196]`` when omitted).
    """

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=None,
        assistant_tokens=None,
    ):
        super().__init__()
        # Use a context manager so the file handle is closed, and read as
        # UTF-8 instead of depending on the platform's locale encoding.
        with open(data_path, encoding="utf-8") as fh:
            self.data = json.load(fh)
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        # None sentinels avoid mutable default arguments; copying keeps the
        # instance independent of the caller's list.
        self.user_tokens = [195] if user_tokens is None else list(user_tokens)
        self.assistant_tokens = (
            [196] if assistant_tokens is None else list(assistant_tokens)
        )
        self.ignore_index = -100

        # Print a decoded preview of the first example as a sanity check;
        # skipped when the dataset is empty instead of raising IndexError.
        if self.data:
            item = self.preprocessing(self.data[0])
            print("input:", self.tokenizer.decode(item["input_ids"]))
            labels = [
                id_
                for id_ in item["labels"].tolist()
                if id_ != self.ignore_index
            ]
            print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        """Return the number of examples loaded from the JSON file."""
        return len(self.data)

    def preprocessing(self, example):
        """Tokenize one conversation into fixed-length training tensors.

        Args:
            example: Mapping with a ``"conversations"`` list of messages.

        Returns:
            Dict with ``input_ids``, ``labels`` and ``attention_mask``
            tensors, each of length ``model_max_length``.
        """
        eos_id = self.tokenizer.eos_token_id
        ignore = self.ignore_index
        input_ids = []
        labels = []
        for message in example["conversations"]:
            value_ids = self.tokenizer.encode(message["value"])
            if message["from"] == "human":
                marker = self.user_tokens
                # The first user-marker position is labelled EOS
                # (presumably so the preceding turn learns to stop under a
                # shifted-label LM loss - confirm against the model); any
                # further marker tokens and the user text are ignored.
                marker_labels = (
                    [eos_id] + [ignore] * (len(marker) - 1) if marker else []
                )
                labels += marker_labels + [ignore] * len(value_ids)
            else:
                marker = self.assistant_tokens
                # Only the assistant text is supervised; every marker token
                # is ignored so labels stay aligned with input_ids.
                labels += [ignore] * len(marker) + value_ids
            input_ids += marker + value_ids

        input_ids.append(eos_id)
        labels.append(eos_id)

        # Truncate, then right-pad both sequences to model_max_length.
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [ignore] * (self.model_max_length - len(labels))

        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        # NOTE: every position equal to pad_token_id is masked out, including
        # any such id that occurs inside the real content.
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the preprocessed example stored at position ``idx``."""
        return self.preprocessing(self.data[idx])
-
@dataclass
- `@dataclass` is a Python decorator that automatically generates special methods such as initialization (`__init__`) and comparison (`__eq__`). It keeps data-class definitions concise.
-
class ModelArguments:
- This line defines a class called `ModelArguments`.
-
model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
- Defines an optional class attribute of type string, `model_name_or_path`, with a default value of `"baichuan-inc/Baichuan2-7B-Base"`.
-
@dataclass
- Same as the first line, used to automatically generate special methods.
-
class DataArguments:
- Defines a class named `DataArguments`.
-
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
- Defines a class attribute of type string, `data_path`, with a default value of `None`. Some metadata describing the field is also attached.
-
@dataclass
- Same as above, for automatically generating special methods.
-
class TrainingArguments(transformers.TrainingArguments):
- Defines a class named `TrainingArguments`, which inherits from `transformers.TrainingArguments`.
-
cache_dir: Optional[str] = field(default=None)
- Defines an optional class attribute of type string, `cache_dir`, with a default value of `None`.
-
optim: str = field(default="adamw_torch")
- Defines a class attribute of type string, `optim`, with a default value of `"adamw_torch"`.
-
model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."})
- Defines an integer attribute, `model_max_length`, with a default value of `512`, and provides a description for it.
-
use_lora: bool = field(default=False)
- Defines a boolean attribute, `use_lora`, which defaults to `False`.
-
class SupervisedDataset(Dataset):
- Defines a class named `SupervisedDataset`, which inherits from the `Dataset` class.
-
"""Dataset for supervised fine-tuning."""
- A short description of the `SupervisedDataset` class.
22-31. def __init__(...):
- Defines the initialization method of the `SupervisedDataset` class, which receives a series of parameters.
-
super(SupervisedDataset, self).__init__()
- Call the initialization method of the parent class (Dataset).
-
self.data = json.load(open(data_path))
- Reads the JSON file specified by the `data_path` parameter and assigns its contents to `self.data`.
25-28. These lines assign the incoming parameters to the corresponding class attributes.
30-32. Preprocess the first example in the data and print its decoded input.
34-40. Decode the preprocessed labels (skipping ignored positions) and print the decoded content.
Overall, this code defines data classes related to model parameters, data parameters, and training parameters, as well as a dataset class for supervised fine-tuning.
-
def __len__(self):
- This is a magic method that is called when you call the `len()` function on a Python object.
-
return len(self.data)
- Returns the length of the data set (`self.data`).
-
def preprocessing(self, example):
- This is a method called `preprocessing`, which preprocesses a single example.
-
input_ids = []
- Initialize an empty list to store the input IDs.
-
labels = []
- Initialize an empty list to store the labels.
-
for message in example["conversations"]:
- Begins a loop over each message in the `example`'s conversations.
-
from_ = message["from"]
- Get the sender of the message.
-
value = message["value"]
- Get the value or content of the message.
-
value_ids = self.tokenizer.encode(value)
- Use the tokenizer to encode the message content into a list of token IDs.
-
if from_ == "human":
- Check whether the sender of the message is "human".
-
input_ids += self.user_tokens + value_ids
- If the message is from "human", append the user-specific token to the `input_ids` list, followed by the token IDs of the message.
-
labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(value_ids)
- Append the end-of-sequence token ID to the labels, followed by one ignore label for each token of the message.
-
else:
- If the message is not from "human".
-
input_ids += self.assistant_tokens + value_ids
- Append the assistant-specific token to the `input_ids` list, followed by the token IDs of the message.
-
labels += [self.ignore_index] + value_ids
- Append one ignore label, followed by the token IDs of the message, to the labels.
-
input_ids.append(self.tokenizer.eos_token_id)
- Append the end-of-sequence token ID at the end of the `input_ids` list.
-
labels.append(self.tokenizer.eos_token_id)
- Append the end-of-sequence token ID at the end of the `labels` list.
23-24. Truncate `input_ids` and `labels` so that their lengths do not exceed `model_max_length`.
26-27. Pad `input_ids` (with the padding token) and `labels` (with the ignore index) so that their lengths reach `model_max_length`.
29-30. Convert `input_ids` and `labels` to PyTorch's `LongTensor` data type.
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
- Create an attention mask marking which tokens in `input_ids` are not padding tokens.
33-36. Return a dictionary containing the preprocessed `input_ids`, `labels` and `attention_mask`.
-
def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
- This is a magic method that is actually called when you try to get an item from the dataset by index.
-
return self.preprocessing(self.data[idx])
- Get an example from `self.data` by index, preprocess it using the `preprocessing` method, and finally return the preprocessed result.
The core of this code is the `preprocessing` method, which preprocesses the conversation data and converts the conversation messages into a format acceptable to the model.
# @dataclass is a Python decorator that automatically generates special
# methods such as initialization and comparison, keeping the data-class
# definition concise.
@dataclass
class ModelArguments:
    """Command-line arguments selecting the base model to fine-tune."""

    # Optional string attribute `model_name_or_path`; its default value is
    # "baichuan-inc/Baichuan2-7B-Base".
    model_name_or_path: Optional[str] = field(
        default="baichuan-inc/Baichuan2-7B-Base",
        metadata={"help": "Model identifier or path of the base model."},
    )
# As above, @dataclass generates the special methods automatically.
@dataclass
class DataArguments:
    """Command-line arguments describing the training data."""

    # String attribute `data_path` with descriptive metadata attached. The
    # default is None, so the annotation must admit None as well.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
# As above, @dataclass generates the special methods automatically.
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extension of ``transformers.TrainingArguments`` with extra fields.

    NOTE: this class reuses the name ``TrainingArguments`` and so shadows
    any same-named import made earlier in the module.
    """

    # Optional string attribute `cache_dir`; default None.
    cache_dir: Optional[str] = field(default=None)
    # String attribute `optim`; default "adamw_torch".
    optim: str = field(default="adamw_torch")
    # Integer attribute `model_max_length`; default 512, with a description.
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Boolean attribute `use_lora`; default False.
    use_lora: bool = field(default=False)
# Defines the `SupervisedDataset` class, which inherits from `Dataset`.
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on multi-turn conversations.

    The JSON file at ``data_path`` is expected to hold a sequence of
    examples, each with a ``"conversations"`` list of messages carrying a
    ``"from"`` and a ``"value"`` key. Messages whose ``"from"`` is
    ``"human"`` are user turns; any other sender is handled as the assistant.

    Args:
        data_path: Path of the JSON training file.
        tokenizer: Object providing ``encode``, ``decode``, ``eos_token_id``
            and ``pad_token_id``.
        model_max_length: Every example is truncated and right-padded to
            exactly this many tokens.
        user_tokens: Token ids placed before each user turn
            (``[195]`` when omitted).
        assistant_tokens: Token ids placed before each assistant turn
            (``[196]`` when omitted).
    """

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=None,
        assistant_tokens=None,
    ):
        # Call the initializer of the parent class (Dataset).
        super().__init__()
        # Read the JSON file named by `data_path` into `self.data`. The
        # context manager closes the handle, and UTF-8 is requested
        # explicitly instead of depending on the locale encoding.
        with open(data_path, encoding="utf-8") as fh:
            self.data = json.load(fh)
        # Store the incoming parameters on the instance. None sentinels
        # avoid mutable default arguments; copying keeps the instance
        # independent of the caller's list.
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = [195] if user_tokens is None else list(user_tokens)
        self.assistant_tokens = (
            [196] if assistant_tokens is None else list(assistant_tokens)
        )
        self.ignore_index = -100

        # Preprocess the first example and print its decoded input and its
        # decoded (non-ignored) labels; skipped for an empty dataset
        # instead of raising IndexError.
        if self.data:
            item = self.preprocessing(self.data[0])
            print("input:", self.tokenizer.decode(item["input_ids"]))
            labels = [
                id_
                for id_ in item["labels"].tolist()
                if id_ != self.ignore_index
            ]
            print("label:", self.tokenizer.decode(labels))

    # Magic method returning the size of the dataset.
    def __len__(self):
        """Return the number of examples loaded from the JSON file."""
        return len(self.data)

    # Preprocessing method applied to a single example.
    def preprocessing(self, example):
        """Tokenize one conversation into fixed-length training tensors.

        Args:
            example: Mapping with a ``"conversations"`` list of messages.

        Returns:
            Dict with ``input_ids``, ``labels`` and ``attention_mask``
            tensors, each of length ``model_max_length``.
        """
        eos_id = self.tokenizer.eos_token_id
        ignore = self.ignore_index
        input_ids = []
        labels = []
        # Walk over every message of the conversation.
        for message in example["conversations"]:
            # Encode the message content with the tokenizer.
            value_ids = self.tokenizer.encode(message["value"])
            if message["from"] == "human":
                marker = self.user_tokens
                # The first user-marker position is labelled EOS
                # (presumably so the preceding turn learns to stop under a
                # shifted-label LM loss - confirm against the model); any
                # further marker tokens and the user text are ignored.
                marker_labels = (
                    [eos_id] + [ignore] * (len(marker) - 1) if marker else []
                )
                labels += marker_labels + [ignore] * len(value_ids)
            else:
                marker = self.assistant_tokens
                # Assistant turn: marker tokens are ignored and the message
                # tokens are supervised, keeping labels aligned with inputs.
                labels += [ignore] * len(marker) + value_ids
            input_ids += marker + value_ids

        # Append the end-of-sequence token to both inputs and labels.
        input_ids.append(eos_id)
        labels.append(eos_id)

        # Truncate inputs and labels.
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        # Right-pad inputs with the pad token and labels with ignore_index.
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [ignore] * (self.model_max_length - len(labels))

        # Convert inputs and labels to PyTorch LongTensors.
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        # Build the attention mask. NOTE: every position equal to
        # pad_token_id is masked out, including any inside real content.
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        # Return the preprocessed result.
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    # Magic method that allows fetching an example by index.
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the preprocessed example stored at position ``idx``."""
        return self.preprocessing(self.data[idx])