Baichuan2 source code analysis: fine-tune/fine-tune.py (1)

 

import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json

import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = []
        labels = []

        for message in example["conversations"]:
            from_ = message["from"]
            value = message["value"]
            value_ids = self.tokenizer.encode(value)

            if from_ == "human":
                input_ids += self.user_tokens + value_ids
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                input_ids += self.assistant_tokens + value_ids
                labels += [self.ignore_index] + value_ids
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])

  1. @dataclass

    • @dataclass is a Python decorator that automatically generates special methods such as __init__ and comparison methods; it keeps the data-class definitions concise (a short usage sketch follows this list).
  2. class ModelArguments:

    • Defines a class named ModelArguments.
  3. model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

    • Defines an optional string attribute model_name_or_path with a default value of "baichuan-inc/Baichuan2-7B-Base".
  4. @dataclass

    • Same as above, used to automatically generate the special methods.
  5. class DataArguments:

    • Defines a class named DataArguments.
  6. data_path: str = field(default=None, metadata={"help": "Path to the training data."})

    • Defines a string attribute data_path with a default value of None, and attaches a help description as field metadata.
  7. @dataclass

    • Same as above, used to automatically generate the special methods.
  8. class TrainingArguments(transformers.TrainingArguments):

    • Defines a class named TrainingArguments that inherits from transformers.TrainingArguments.
  9. cache_dir: Optional[str] = field(default=None)

    • Defines an optional string attribute cache_dir with a default value of None.
  10. optim: str = field(default="adamw_torch")

    • Defines a string attribute optim with a default value of "adamw_torch".
  11. model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."})

    • Defines an integer attribute model_max_length with a default value of 512, together with a help description.
  12. use_lora: bool = field(default=False)

    • Defines a boolean attribute use_lora, which defaults to False.
  13. class SupervisedDataset(Dataset):

    • Defines a class named SupervisedDataset that inherits from the Dataset class.
  14. """Dataset for supervised fine-tuning."""

    • A short docstring describing the SupervisedDataset class.
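
To make the dataclass mechanics above concrete, here is a minimal standalone sketch (not part of fine-tune.py); the override value "my-org/my-model" is purely hypothetical.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

# @dataclass generates __init__ and __repr__, so both of these work:
default_args = ModelArguments()  # keeps the default checkpoint name
custom_args = ModelArguments(model_name_or_path="my-org/my-model")  # hypothetical override
print(default_args.model_name_or_path)  # baichuan-inc/Baichuan2-7B-Base
print(custom_args)  # ModelArguments(model_name_or_path='my-org/my-model')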

22-31. def __init__(...): — Defines the initialization method of the SupervisedDataset class, which receives a series of parameters.

  1. super(SupervisedDataset, self).__init__()

    • Calls the initialization method of the parent class (Dataset).
  2. self.data = json.load(open(data_path))

    • Reads the JSON file specified by the data_path parameter and assigns its contents to self.data (a hedged sketch of the expected file layout follows below).

25-28. These lines assign the incoming parameters to the corresponding instance attributes.

30-32. Preprocess the first example and print its decoded input.

34-40. Decode the preprocessed labels (skipping the ignore index) and print the decoded content.

Overall, this code defines data classes for the model, data, and training parameters, as well as a dataset class for supervised fine-tuning.
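
The script itself never shows what the training JSON looks like, but from the way preprocessing reads example["conversations"], message["from"], and message["value"], the expected layout is roughly a list of objects like the one below. This is a hedged sketch: the assistant role name "gpt" and the example texts are assumptions (the code only checks whether "from" equals "human").

[
  {
    "conversations": [
      {"from": "human", "value": "What is the capital of France?"},
      {"from": "gpt", "value": "The capital of France is Paris."}
    ]
  }
]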

  1. def __len__(self):

    • This is a magic method that is called when you invoke len() on the dataset object.
  2. return len(self.data)

    • Returns the number of examples in the dataset (len(self.data)).
  3. def preprocessing(self, example):

    • Defines the preprocessing method, which preprocesses a single example.
  4. input_ids = []

    • Initializes an empty list to hold the input token IDs.
  5. labels = []

    • Initializes an empty list to hold the labels.
  6. for message in example["conversations"]:

    • Loops over each message in the example's conversation.
  7. from_ = message["from"]

    • Gets the sender of the message.
  8. value = message["value"]

    • Gets the content of the message.
  9. value_ids = self.tokenizer.encode(value)

    • Uses the tokenizer to encode the message content into token IDs.
  10. if from_ == "human":

    • Checks whether the sender of the message is "human".
  11. input_ids += self.user_tokens + value_ids

    • If the message is from "human", appends the user-specific token to input_ids, followed by the message's token IDs.
  12. labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(value_ids)

    • Adds the end-of-sequence token ID as the label aligned with the user token, and the ignore index for every token of the user message, so no loss is computed on user text.
  13. else:

    • Otherwise the message is from the assistant.
  14. input_ids += self.assistant_tokens + value_ids

    • Appends the assistant-specific token to input_ids, followed by the message's token IDs.
  15. labels += [self.ignore_index] + value_ids

    • Adds the ignore index aligned with the assistant token, followed by the message's token IDs, so the loss is computed on the assistant reply.
  16. input_ids.append(self.tokenizer.eos_token_id)

    • Appends the end-of-sequence token ID at the end of input_ids.
  17. labels.append(self.tokenizer.eos_token_id)

    • Appends the end-of-sequence token ID at the end of labels.

23-24. Truncate input_ids and labels so that their length does not exceed model_max_length.

26-27. Pad input_ids with the pad token and labels with the ignore index so that both reach model_max_length.

29-30. Convert input_ids and labels to PyTorch LongTensor.

  1. attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
    • Creates an attention mask marking which tokens in input_ids are not padding tokens.

33-36. Return a dictionary containing the preprocessed input_ids, labels, and attention_mask (a small illustration of the resulting alignment follows below).
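
To make the resulting alignment concrete, here is a small standalone illustration (not part of fine-tune.py) that mirrors the masking and padding logic above with hypothetical token IDs: 195/196 are the default user/assistant tokens, while 2 (eos), 0 (pad), the value IDs, and model_max_length = 10 are made up for the example.

import torch

# one human turn ([11, 12]) followed by one assistant turn ([21, 22]), then eos
input_ids = [195, 11, 12, 196, 21, 22, 2]
labels = [2, -100, -100, -100, 21, 22, 2]  # loss only on the assistant reply and final eos
model_max_length, pad_token_id, ignore_index = 10, 0, -100

# pad to model_max_length, exactly as in preprocessing()
input_ids += [pad_token_id] * (model_max_length - len(input_ids))
labels += [ignore_index] * (model_max_length - len(labels))

input_ids = torch.LongTensor(input_ids)
labels = torch.LongTensor(labels)
attention_mask = input_ids.ne(pad_token_id)  # True for real tokens, False for padding

print(input_ids)       # tensor([195, 11, 12, 196, 21, 22, 2, 0, 0, 0])
print(labels)          # tensor([2, -100, -100, -100, 21, 22, 2, -100, -100, -100])
print(attention_mask)  # the last three positions are False (padding)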

  1. def __getitem__(self, idx) -> Dict[str, torch.Tensor]:

    • This is a magic method that is called when you retrieve an item from the dataset by index.
  2. return self.preprocessing(self.data[idx])

    • Fetches the example at the given index from self.data, runs it through the preprocessing method, and returns the result.

The core of this code is the preprocessing method: it preprocesses the conversation data and converts the conversation messages into a format the model can accept.
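
This part of the file stops at the dataset; the remaining code in fine-tune.py (not reproduced here) wires these pieces into a standard Hugging Face training loop. The sketch below shows that typical pattern and should be read as an assumption about the omitted code rather than a copy of it; HfArgumentParser, AutoTokenizer/AutoModelForCausalLM, and Trainer are standard transformers APIs, and trust_remote_code=True is needed because Baichuan2 ships custom model code.

from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, Trainer

# Parse command-line flags into the three dataclasses defined above.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Load the tokenizer and base model.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    use_fast=False,
    trust_remote_code=True,
    model_max_length=training_args.model_max_length,
)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True,
    cache_dir=training_args.cache_dir,
)

# Build the dataset defined above and hand everything to the Trainer.
train_dataset = SupervisedDataset(
    data_args.data_path, tokenizer, training_args.model_max_length
)
trainer = Trainer(
    model=model, args=training_args, train_dataset=train_dataset, tokenizer=tokenizer
)
trainer.train()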

# @dataclass is a Python decorator that automatically generates special methods such as __init__ and comparison methods. It keeps the data-class definition concise.
@dataclass
# This line defines a class named `ModelArguments`.
class ModelArguments:
    # Defines an optional string attribute `model_name_or_path` with a default value of `"baichuan-inc/Baichuan2-7B-Base"`.
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

# Same as above, used to automatically generate the special methods.
@dataclass
# Defines a class named `DataArguments`.
class DataArguments:
    # Defines a string attribute `data_path` with a default value of `None`, plus some metadata describing the field.
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

# Same as above, used to automatically generate the special methods.
@dataclass
# Defines a class named `TrainingArguments` that inherits from `transformers.TrainingArguments`.
class TrainingArguments(transformers.TrainingArguments):
    # Defines an optional string attribute `cache_dir` with a default value of `None`.
    cache_dir: Optional[str] = field(default=None)
    # Defines a string attribute `optim` with a default value of `"adamw_torch"`.
    optim: str = field(default="adamw_torch")
    # Defines an integer attribute `model_max_length` with a default value of `512`, together with a description.
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Defines a boolean attribute `use_lora`, which defaults to `False`.
    use_lora: bool = field(default=False)

# Defines a class named `SupervisedDataset` that inherits from the `Dataset` class.
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    # Defines the initialization method of the `SupervisedDataset` class, which receives a series of parameters.
    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        # Calls the initialization method of the parent class (Dataset).
        super(SupervisedDataset, self).__init__()
        # Reads the JSON file specified by the `data_path` parameter and assigns its contents to `self.data`.
        self.data = json.load(open(data_path))
        # These lines assign the incoming parameters to the corresponding instance attributes.
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        # Preprocess the first example and print its decoded input.
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        # Decode the preprocessed labels (skipping the ignore index) and print the decoded content.
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    # Magic method that returns the size of the dataset.
    def __len__(self):
        return len(self.data)

    # Preprocessing method that preprocesses a single example.
    def preprocessing(self, example):
        # Initialize an empty list for the input IDs.
        input_ids = []
        # Initialize an empty list for the labels.
        labels = []

        # Loop over each message in the conversation.
        for message in example["conversations"]:
            # Get the sender of the message.
            from_ = message["from"]
            # Get the content of the message.
            value = message["value"]
            # Encode the message content with the tokenizer.
            value_ids = self.tokenizer.encode(value)

            # If the message comes from the human user.
            if from_ == "human":
                # Add the user-specific token and the message tokens to the input IDs.
                input_ids += self.user_tokens + value_ids
                # Add the end-of-sequence token and ignore labels to the labels.
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                # If the message comes from the assistant, add the assistant-specific token and the message tokens.
                input_ids += self.assistant_tokens + value_ids
                # Add the ignore label and the message tokens to the labels.
                labels += [self.ignore_index] + value_ids

        # Append the end-of-sequence token to both the input IDs and the labels.
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)

        # Truncate the input IDs and labels.
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]

        # Pad the input IDs and labels.
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))

        # Convert the input IDs and labels to PyTorch LongTensor.
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)

        # Create the attention mask.
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # Return the preprocessed result.
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    # Magic method that allows retrieving an example from the dataset by index.
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])


Origin blog.csdn.net/sinat_37574187/article/details/132783096