Baichuan2源码解析 fine-tune/fine-tune.py (一)

import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json

import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = []
        labels = []

        for message in example["conversations"]:
            from_ = message["from"]
            value = message["value"]
            value_ids = self.tokenizer.encode(value)

            if from_ == "human":
                input_ids += self.user_tokens + value_ids
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                input_ids += self.assistant_tokens + value_ids
                labels += [self.ignore_index] + value_ids
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])

  1. @dataclass

    • @dataclass 是一个Python装饰器,用于自动生成初始化、比较等特殊方法。它使得数据类定义变得简洁。
  2. class ModelArguments:

    • 这行定义了一个名为 ModelArguments 的类
  3. model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

    • 定义了一个可选的字符串类型的类属性 model_name_or_path,其默认值为 "baichuan-inc/Baichuan2-7B-Base"
  4. @dataclass

    • 同第一行,用于自动生成特殊方法。
  5. class DataArguments:

    • 定义了一个名为 DataArguments 的类。
  6. data_path: str = field(default=None, metadata={"help": "Path to the training data."})

    • 定义了一个字符串类型的类属性 data_path,其默认值为 None。还为该字段添加了一些元数据描述。
  7. @dataclass

    • 同上,用于自动生成特殊方法。
  8. class TrainingArguments(transformers.TrainingArguments):

    • 定义了一个名为 TrainingArguments 的类,该类继承自 transformers.TrainingArguments
  9. cache_dir: Optional[str] = field(default=None)

    • 定义了一个可选的字符串类型的类属性 cache_dir,其默认值为 None
  10. optim: str = field(default="adamw_torch")

    • 定义了一个字符串类型的类属性 optim,其默认值为 "adamw_torch"
  11. model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."})

    • 定义了一个整型属性 model_max_length,其默认值为 512,并为它提供了描述。
  12. use_lora: bool = field(default=False)

    • 定义了一个布尔值属性 use_lora,默认为 False
  13. class SupervisedDataset(Dataset):

    • 定义了一个名为 SupervisedDataset 的类,继承自 Dataset 类。
  14. """Dataset for supervised fine-tuning."""

    • SupervisedDataset 类的简短描述。

22-31. def __init__(...): - 定义了 SupervisedDataset 类的初始化方法,并接收一系列参数。

  1. super(SupervisedDataset, self).__init__()

    • 调用父类(Dataset)的初始化方法。
  2. self.data = json.load(open(data_path))

    • 读取由 data_path 参数指定的JSON文件,并将其内容赋值给 self.data

25-28. 这几行将传入的参数赋值给相应的类属性。

30-32. 对第一个数据进行预处理,并打印它的输入。

34-40. 对预处理后的标签进行解码,并打印解码后的内容。

整体上,这段代码定义了与模型参数、数据参数和训练参数相关的数据类,以及一个用于监督细调的数据集类。

  1. def __len__(self):

    • 这是一个魔法方法,当你调用Python对象的len()函数时,它实际上调用的是这个方法。
  2. return len(self.data)

    • 返回数据集self.data)的长度。
  3. def preprocessing(self, example):

    • 这是一个名为preprocessing的方法,它将对单个示例进行预处理。
  4. input_ids = []

    • 初始化一个空列表来存储输入的ID。
  5. labels = []

    • 初始化一个空列表来存储标签。
  6. for message in example["conversations"]:

    • 开始循环处理example中的每个消息。
  7. from_ = message["from"]

    • 获取消息的发送者。
  8. value = message["value"]

    • 获取消息的值或内容。
  9. value_ids = self.tokenizer.encode(value)

    • 使用tokenizer将消息内容编码为token ID。
  10. if from_ == "human":

    • 检查消息的发送者是否为“human”。
  11. input_ids += self.user_tokens + value_ids

    • 如果消息来自“human”,则input_ids列表中追加特定于用户的token,然后追加消息的token ID。
  12. labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(value_ids)

    • 添加结束符token ID到标签,并为消息的每个token添加忽略标签。
  13. else:

    • 如果消息不是来自“human”。
  14. input_ids += self.assistant_tokens + value_ids

    • input_ids列表中追加特定于助手的token,然后追加消息的token ID。
  15. labels += [self.ignore_index] + value_ids

    • 添加忽略标签和消息的token ID到标签。
  16. input_ids.append(self.tokenizer.eos_token_id)

    • input_ids列表的末尾追加结束符token ID。
  17. labels.append(self.tokenizer.eos_token_id)

    • labels列表的末尾追加结束符token ID。

23-24. 对input_idslabels进行截断,使其长度不超过model_max_length

26-27. 为input_idslabels填充token,使其长度达到model_max_length

29-30. 将input_idslabels转换为PyTorch的LongTensor数据类型。

  1. attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
    • 创建一个注意力掩码,标记input_ids中哪些token不是填充token。

33-36. 返回一个字典,包含预处理后的input_idslabelsattention_mask

  1. def __getitem__(self, idx) -> Dict[str, torch.Tensor]:

    • 这是一个魔法方法,当你尝试通过索引从数据集中获取项目时,它实际上调用的是这个方法。
  2. return self.preprocessing(self.data[idx])

    • 通过索引self.data中获取一个示例,然后使preprocessing方法对其进行预处理,最后返回预处理后的结果。

这段代码的核心是preprocessing方法,它对对话数据进行预处理,将对话消息转换为模型可以接受的格式。

# @dataclass 是一个Python装饰器,用于自动生成初始化、比较等特殊方法。它使得数据类定义变得简洁。
@dataclass
# 这行定义了一个名为 `ModelArguments` 的类。
class ModelArguments:
    # 定义了一个可选的字符串类型的类属性 `model_name_or_path`,其默认值为 `"baichuan-inc/Baichuan2-7B-Base"`。
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

# 同第一行,用于自动生成特殊方法。
@dataclass
# 定义了一个名为 `DataArguments` 的类。
class DataArguments:
    # 定义了一个字符串类型的类属性 `data_path`,其默认值为 `None`。还为该字段添加了一些元数据描述。
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

# 同上,用于自动生成特殊方法。
@dataclass
# 定义了一个名为 `TrainingArguments` 的类,该类继承自 `transformers.TrainingArguments`。
class TrainingArguments(transformers.TrainingArguments):
    # 定义了一个可选的字符串类型的类属性 `cache_dir`,其默认值为 `None`。
    cache_dir: Optional[str] = field(default=None)
    # 定义了一个字符串类型的类属性 `optim`,其默认值为 `"adamw_torch"`。
    optim: str = field(default="adamw_torch")
    # 定义了一个整型属性 `model_max_length`,其默认值为 `512`,并为它提供了描述。
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # 定义了一个布尔值属性 `use_lora`,默认为 `False`。
    use_lora: bool = field(default=False)

# 定义了一个名为 `SupervisedDataset` 的类,继承自 `Dataset` 类。
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    # 定义了 `SupervisedDataset` 类的初始化方法,并接收一系列参数。
    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        # 调用父类(Dataset)的初始化方法。
        super(SupervisedDataset, self).__init__()
        # 读取由 `data_path` 参数指定的JSON文件,并将其内容赋值给 `self.data`。
        self.data = json.load(open(data_path))
        # 这几行将传入的参数赋值给相应的类属性。
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        # 对第一个数据进行预处理,并打印它的输入。
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        # 对预处理后的标签进行解码,并打印解码后的内容。
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    # 定义魔法方法,返回数据集的大小。
    def __len__(self):
        return len(self.data)

    # 定义预处理方法,对单个示例进行预处理。
    def preprocessing(self, example):
        # 初始化输入ID的空列表。
        input_ids = []
        # 初始化标签的空列表。
        labels = []

        # 遍历每个对话中的消息。
        for message in example["conversations"]:
            # 获取消息的发送者。
            from_ = message["from"]
            # 获取消息的内容。
            value = message["value"]
            # 使用tokenizer对消息内容进行编码。
            value_ids = self.tokenizer.encode(value)

            # 如果消息来自人类用户。
            if from_ == "human":
                # 在输入ID中添加用户特定的token和消息token。
                input_ids += self.user_tokens + value_ids
                # 在标签中添加结束符和忽略标签。
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                # 如果消息来自助手,添加助手特定的token和消息token。
                input_ids += self.assistant_tokens + value_ids
                # 在标签中添加忽略标签和消息token。
                labels += [self.ignore_index] + value_ids

        # 在输入ID和标签的末尾都追加结束符token。
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)

        # 对输入ID和标签进行截断。
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]

        # 为输入ID和标签填充token。
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))

        # 转换输入ID和标签为PyTorch的LongTensor。
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)

        # 创建注意力掩码。
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # 返回预处理后的结果。
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    # 定义魔法方法,允许使用索引从数据集中获取示例。
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/132783096
今日推荐