Source code analysis ChatGLM Efficient Tuning/src/utils/config.py

import os
import json
import torch
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: Optional[str] = field(
        default="THUDM/chatglm-6b",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    use_v2: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use ChatGLM2 or not."}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name."}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
            self.model_name_or_path = "THUDM/chatglm2-6b"

        if self.checkpoint_dir is not None: # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

This code defines two data classes: DatasetAttr and ModelArguments.

  1. from typing import Any, Dict, List, Literal, Optional: imports type-annotation tools from the typing module: Any (any type), Dict (dictionary type), List (list type), Literal (a type restricted to the listed literal values), and Optional (the value may be None).

  2. from dataclasses import asdict, dataclass, field: imports data-class tools from the dataclasses module: asdict (converts a data class instance to a dictionary), dataclass (the decorator that defines a data class), and field (configures an individual field of a data class).

  3. Define the DatasetAttr class:

    • @dataclass is a decorator that marks DatasetAttr as a data class.
    • This class has three fields: load_from, dataset_name, and dataset_sha1. They indicate where to load the dataset from, the name of the dataset, and the SHA-1 value of the dataset, respectively. The last two fields are optional and may be None.
    • The __repr__ method is overridden so that the string representation of an instance is the value of its dataset_name field.
    • In the __post_init__ method, four more attributes are set: prompt_column, query_column, response_column, and history_column. They are assigned automatically right after the object is instantiated. Because they are not declared in the class body, they are plain instance attributes rather than dataclass fields.
  4. Define the ModelArguments class:

    • This class has many fields, each defined with the field function, which sets the field's default value and attaches metadata (here a help string) describing what the field means.
    • This class contains parameters for fine-tuning the model: model name or path, configuration name, tokenizer name, cache directory, whether to use the fast tokenizer, model revision, whether to use the auth token, quantization bits, quantization type, whether to use double quantization, checkpoint directory, reward model path, whether to resume LoRA training, whether to plot the loss, and so on.
    • In the __post_init__ method, model_name_or_path may be modified depending on the values of use_v2 and model_name_or_path. If checkpoint_dir is not None, it is split on commas (,). And if quantization_bit is not None, the method asserts that its value is 4 or 8.
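The difference between declared fields and attributes assigned in __post_init__ can be checked directly. The following is a minimal sketch that copies DatasetAttr from the listing above, leaving out its custom __repr__ for brevity; the argument values "file" and "my_dataset" are placeholders chosen for illustration.

```python
from dataclasses import asdict, dataclass, fields
from typing import Optional


@dataclass
class DatasetAttr:
    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __post_init__(self):
        # Runs at the end of the generated __init__. These are ordinary
        # instance attributes, not declared dataclass fields.
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


# Placeholder values, for illustration only.
attr = DatasetAttr(load_from="file", dataset_name="my_dataset")

field_names = [f.name for f in fields(attr)]
as_dict = asdict(attr)

print(field_names)                 # ['load_from', 'dataset_name', 'dataset_sha1']
print(attr.prompt_column)          # instruction
print("prompt_column" in as_dict)  # False
```

One consequence is that the column names cannot be passed to the constructor and do not take part in the generated __eq__ comparison; they have to be changed by assigning to the attribute after the object is created.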

Each field of the ModelArguments class is explained in detail next:

  • model_name_or_path: The name or path of the pretrained model. Defaults to "THUDM/chatglm-6b", a model identifier from huggingface.co/models. In the __post_init__ method, if use_v2 is true and model_name_or_path is still "THUDM/chatglm-6b", it is changed to "THUDM/chatglm2-6b".

  • use_v2: Indicates whether to use ChatGLM2. Defaults to false.

  • config_name: The name or path of the pretrained configuration. It only needs to be set when it differs from model_name_or_path.

  • tokenizer_name: The name or path of the pretrained tokenizer. It only needs to be set when it differs from model_name_or_path.

  • cache_dir: Location to store pretrained models downloaded from huggingface.co.

  • use_fast_tokenizer: Indicates whether to use a fast tokenizer (one backed by the tokenizers library). Defaults to true.

  • model_revision: The specific model version to use (can be a branch name, tag name, or commit id). Defaults to "main".

  • use_auth_token: Whether to use the token generated when running huggingface-cli login. Defaults to false.

  • quantization_bit: The number of bits to quantize the model to. In the __post_init__ method, if quantization_bit is not None, it is asserted that its value is 4 or 8.

  • quantization_type: Quantized data type for int4 training. Defaults to "nf4".

  • double_quantization: Whether to use double quantization in int4 training. Defaults to true.

  • compute_dtype: Data type used in the quantization configuration. According to its help text, this argument should not be specified manually.

  • checkpoint_dir: Path to the directory containing model checkpoints and configurations. In the __post_init__ method, if checkpoint_dir is not None, it is split on commas (,) into a list of paths with surrounding whitespace stripped, which supports merging several sets of LoRA weights.

  • reward_model: Path to the directory containing checkpoints for the reward model.

  • resume_lora_training: Indicates whether to continue training from the last LoRA weights, or create new weights after merging weights. Defaults to true.

  • plot_loss: Indicates whether to plot the training loss after fine-tuning. Defaults to false.

All of the ModelArguments fields have default values, but custom values can also be provided when the object is created. They control the fine-tuning process of the model, including which model to use, how to quantize the model, and whether to plot the training loss.
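The __post_init__ behavior described above can be tried without torch or a model download. The sketch below is a trimmed, hypothetical ModelArgumentsSketch class that keeps only the four fields __post_init__ touches; the shortened help strings and the checkpoint paths "ckpt/a" and "ckpt/b" are made up for illustration.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ModelArgumentsSketch:
    """Hypothetical cut-down copy of ModelArguments, for illustration only."""
    model_name_or_path: Optional[str] = field(
        default="THUDM/chatglm-6b",
        metadata={"help": "Path to pretrained model or model identifier."},
    )
    use_v2: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use ChatGLM2 or not."},
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."},
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Comma-separated checkpoint directories."},
    )

    def __post_init__(self):
        # Same three checks as in the original __post_init__.
        if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
            self.model_name_or_path = "THUDM/chatglm2-6b"

        if self.checkpoint_dir is not None:
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."


args = ModelArgumentsSketch(use_v2=True, checkpoint_dir="ckpt/a, ckpt/b", quantization_bit=4)
print(args.model_name_or_path)  # THUDM/chatglm2-6b
print(args.checkpoint_dir)      # ['ckpt/a', 'ckpt/b']

# An unsupported bit width is rejected while the object is being constructed.
try:
    ModelArgumentsSketch(quantization_bit=16)
    rejected = False
except AssertionError:
    rejected = True
print(rejected)  # True

# The metadata passed to field() can be read back through fields().
help_texts = {f.name: f.metadata["help"] for f in fields(ModelArgumentsSketch)}
print(help_texts["use_v2"])  # Whether to use ChatGLM2 or not.
```

Note that checkpoint_dir is annotated as Optional[str] but holds a list after __post_init__ runs, so later code has to treat it as a list of paths. The check relies on assert, so it is skipped when Python runs with the -O flag.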

Please explain @dataclass in detail

In Python, @dataclass is a decorator that automatically adds special methods (such as __init__ and __repr__) to a user-defined class, which simplifies the class definition. It comes from the dataclasses module, available in Python 3.7 and later.

Classes decorated with @dataclass are called data classes. A data class is typically used to store values, and it gets several automatically generated special methods.

For example, we can define a data class to represent the coordinates of a point:

from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

In this example, the Point class has two fields, x and y, both of type int.

The @dataclass decorator automatically adds an __init__ method to the Point class. This __init__ takes two arguments (x and y) and uses them to initialize the Point object. Therefore, we can create a Point object like this:

p = Point(1, 2)

In addition, the @dataclass decorator automatically adds other special methods to the class, including:

  • __repr__: Returns a string representing the object, e.g. Point(x=1, y=2).
  • __eq__: Determines whether two objects are equal. By default, two objects are equal if they are of the same type and all their fields are equal.
  • __hash__: Generated only when the data class is declared immutable with @dataclass(frozen=True) and eq is left at its default of True. It computes the object's hash value so that the object can be used as a dictionary key. With the default settings (eq=True, frozen=False), __hash__ is set to None and instances are unhashable.
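The generated methods listed above can be observed on the Point class. This is a small sketch; FrozenPoint is a hypothetical second class added only to contrast the hashing behavior of a frozen data class with the default one.

```python
from dataclasses import dataclass


@dataclass
class Point:
    x: int
    y: int


@dataclass(frozen=True)
class FrozenPoint:
    # Hypothetical immutable variant, used only for the hashing comparison.
    x: int
    y: int


p = Point(1, 2)
print(repr(p))           # Point(x=1, y=2)
print(p == Point(1, 2))  # True
print(p == Point(2, 1))  # False

# A default (mutable) data class defines __eq__, so its __hash__ is None.
try:
    hash(p)
    mutable_is_hashable = True
except TypeError:
    mutable_is_hashable = False
print(mutable_is_hashable)  # False

# A frozen data class gets a field-based __hash__ and can be a dict key.
lookup = {FrozenPoint(1, 2): "found"}
print(lookup[FrozenPoint(1, 2)])  # found
```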

In the code you provided, the @dataclass decorator is used to define the two data classes DatasetAttr and ModelArguments. This makes it easy to create objects of these classes, print their representations, and compare them for equality.

What are the similarities and differences between __str__ and __repr__?

String representation

Python's built-in function repr() expresses an object as a string so that we can identify it easily. This is the "string representation". repr() obtains it through the special method __repr__. If __repr__ is not implemented, printing a Vector instance in the console gives something like <Vector object at 0x10a514f98>.

__repr__ is a Python special method known as the "official" string representation of an object. If defined, it is called when you use the repr() function. print() and str() call __str__ instead, which falls back to __repr__ when the class does not define its own __str__. The return value of __repr__ should be as unambiguous as possible, which makes it useful when inspecting objects or debugging code.
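To make the relationship between the two methods concrete, here is a short sketch. Vector and ReprOnly are hypothetical classes written for this illustration; they are not part of the config.py file.

```python
class Vector:
    """Hypothetical class that defines both string methods."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # Unambiguous, developer-facing form.
        return f"Vector({self.x!r}, {self.y!r})"

    def __str__(self):
        # Shorter, user-facing form.
        return f"({self.x}, {self.y})"


class ReprOnly:
    """Hypothetical class that defines only __repr__."""

    def __repr__(self):
        return "ReprOnly()"


v = Vector(3, 4)
print(repr(v))          # Vector(3, 4)
print(str(v))           # (3, 4)
print(v)                # (3, 4)  print() uses __str__
print([v])              # [Vector(3, 4)]  containers show elements with repr()
print(str(ReprOnly()))  # ReprOnly()  str() falls back to __repr__
```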

In your code, the __repr__ method is overridden to return the value of the dataset_name field. Because the class does not define __str__, this means that both printing a DatasetAttr object and calling repr() on it give the value of the object's dataset_name field.

For example, if you have a DatasetAttr object whose dataset_name field is "my_dataset", then printing this object gives the string "my_dataset":

dataset_attr = DatasetAttr(load_from="path/to/dataset", dataset_name="my_dataset")
print(dataset_attr)  # Output: my_dataset

It is worth noting that this __repr__ implementation does not follow the usual convention, because the string it returns does not identify the object unambiguously. Where possible, the return value of __repr__ should be a string that eval() can turn back into an object equal to the original. For example, for an object of the Point class, __repr__ might return a string like "Point(x=1, y=2)". In addition, dataset_name defaults to None, and repr() raises a TypeError when __repr__ returns something that is not a string.
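Both points can be checked with a short sketch. Point uses the __repr__ generated by @dataclass, and NamedDataset is a hypothetical stand-in for DatasetAttr that keeps only two fields and the same __repr__ override; "file" is a placeholder value.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Point:
    x: int
    y: int


# The generated repr can be evaluated back into an equal object.
p = Point(1, 2)
rebuilt = eval(repr(p))
print(rebuilt == p)  # True


@dataclass
class NamedDataset:
    # Hypothetical stand-in for DatasetAttr with the same __repr__ override.
    load_from: str
    dataset_name: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name


# With a name set, repr() returns just the name.
print(repr(NamedDataset(load_from="file", dataset_name="my_dataset")))  # my_dataset

# Without a name, __repr__ returns None and repr() raises TypeError.
try:
    repr(NamedDataset(load_from="file"))
    error_name = None
except TypeError as err:
    error_name = type(err).__name__
print(error_name)  # TypeError
```

A __repr__ that guards against this could return something like str(self.dataset_name), or include the class name and all fields as the generated version does.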

Origin blog.csdn.net/sinat_37574187/article/details/131487800