import os
import json
import torch
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: Optional[str] = field(
        default="THUDM/chatglm-6b",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    use_v2: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use ChatGLM2 or not."}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name."}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
            self.model_name_or_path = "THUDM/chatglm2-6b"

        if self.checkpoint_dir is not None:  # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

This code defines two data classes: DatasetAttr and ModelArguments.
- from typing import Any, Dict, List, Literal, Optional: imports type-annotation tools from the typing module, namely Any (any type), Dict (dictionary type), List (list type), Literal (a type restricted to specific literal values), and Optional (the value may be None).
- from dataclasses import asdict, dataclass, field: imports data-class tools from the dataclasses module, namely asdict (converts a data class object to a dictionary), dataclass (the decorator that defines a data class), and field (customizes an individual field of a data class).
- Define the DatasetAttr class:
  - @dataclass is a decorator that marks DatasetAttr as a data class.
  - The class has three fields: load_from, dataset_name, and dataset_sha1, which indicate where to load the dataset from, the name of the dataset, and the SHA-1 value of the dataset, respectively. The last two fields are optional and default to None.
  - It overrides the __repr__ method so that the string representation of an instance is the value of its dataset_name field.
  - The __post_init__ method sets four more attributes: prompt_column, query_column, response_column, and history_column. They are assigned automatically right after the object is instantiated. Because they are set inside __post_init__ rather than declared with annotations in the class body, they are ordinary instance attributes, not data class fields.
- Define the ModelArguments class:
  - The class has many fields, each defined with the field function, which sets the field's default value and attaches metadata (here, a help string that describes the field).
  - The fields are parameters for fine-tuning the model: model name or path, configuration name, tokenizer name, cache directory, model revision, whether to use the fast tokenizer, whether to use the authentication token, quantization bits, quantization type, whether to use double quantization, checkpoint directory, reward model path, whether to resume LoRA training, and whether to plot the loss.
  - The __post_init__ method may change model_name_or_path depending on the values of use_v2 and model_name_or_path. If checkpoint_dir is not None, it is split on commas (,). If quantization_bit is not None, the method asserts that its value is 4 or 8.
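The difference between declared fields and attributes set in __post_init__ can be checked with dataclasses.fields. The sketch below copies DatasetAttr from the code above but leaves out its custom __repr__ to stay short; the argument values are made up for illustration.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class DatasetAttr:
    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __post_init__(self):
        # Runs right after the generated __init__; these are plain
        # instance attributes, not data class fields.
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


# Hypothetical values, chosen only for this example.
attr = DatasetAttr(load_from="file", dataset_name="my_dataset")

# fields() reports only the annotated class-level fields.
field_names = [f.name for f in fields(attr)]
print(field_names)         # ['load_from', 'dataset_name', 'dataset_sha1']
print(attr.prompt_column)  # instruction
```

One consequence is that the four column attributes do not take part in the generated __init__ or __eq__, and asdict would not include them.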
Each field of the ModelArguments class is explained in detail below:
- model_name_or_path: the name or path of the pretrained model. Defaults to "THUDM/chatglm-6b", an identifier of a pretrained model. In the __post_init__ method, if use_v2 is true and model_name_or_path is "THUDM/chatglm-6b", model_name_or_path is changed to "THUDM/chatglm2-6b".
- use_v2: whether to use ChatGLM2. Defaults to false.
- config_name: the name or path of the pretrained configuration. Set it only if it differs from model_name_or_path.
- tokenizer_name: the name or path of the pretrained tokenizer. Set it only if it differs from model_name_or_path.
- cache_dir: the location in which to store pretrained models downloaded from huggingface.co.
- use_fast_tokenizer: whether to use the fast tokenizer (backed by the tokenizers library). Defaults to true.
- model_revision: the specific model version to use (can be a branch name, tag name, or commit id). Defaults to "main".
- use_auth_token: whether to use the token generated when running huggingface-cli login. Defaults to false.
- quantization_bit: the number of bits used to quantize the model. In the __post_init__ method, if quantization_bit is not None, it is asserted that its value is 4 or 8.
- quantization_type: the quantization data type for int4 training, either "fp4" or "nf4". Defaults to "nf4".
- double_quantization: whether to use double quantization in int4 training. Defaults to true.
- compute_dtype: the data type used in the quantization configuration. This parameter should not be specified manually.
- checkpoint_dir: the path to the directory containing model checkpoints and configurations. In the __post_init__ method, if checkpoint_dir is not None, it is split on commas (,) into a list of paths, which supports merging several sets of LoRA weights.
- reward_model: the path to the directory containing the checkpoints of the reward model.
- resume_lora_training: whether to resume training from the last LoRA weights, or to create new weights after merging them. Defaults to true.
- plot_loss: whether to plot the training loss after fine-tuning. Defaults to false.
Most fields of ModelArguments have default values, but custom values can also be provided when the object is created. Together they control the fine-tuning process of the model, including which model to use, how to quantize it, and whether to plot the training loss.
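To see the __post_init__ handling in isolation, here is a minimal sketch. ModelArgumentsSketch is a hypothetical, trimmed-down stand-in that keeps only the four fields the method touches (and drops the torch dependency); the checkpoint paths are made up.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgumentsSketch:  # hypothetical stand-in, not the full ModelArguments
    model_name_or_path: Optional[str] = "THUDM/chatglm-6b"
    use_v2: Optional[bool] = False
    checkpoint_dir: Optional[str] = None
    quantization_bit: Optional[int] = None

    def __post_init__(self):
        # Same three checks as in the original class.
        if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
            self.model_name_or_path = "THUDM/chatglm2-6b"
        if self.checkpoint_dir is not None:
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."


args = ModelArgumentsSketch(use_v2=True, checkpoint_dir="ckpt/a, ckpt/b", quantization_bit=4)
print(args.model_name_or_path)  # THUDM/chatglm2-6b
print(args.checkpoint_dir)      # ['ckpt/a', 'ckpt/b']

# An unsupported bit width is rejected at construction time.
rejected = False
try:
    ModelArgumentsSketch(quantization_bit=16)
except AssertionError as err:
    rejected = True
    print(err)
```

Note that the check relies on an assert statement, so it is skipped when Python runs with the -O flag.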
Please explain @dataclass in detail
In Python, @dataclass is a decorator that automatically adds special methods (such as __init__ and __repr__) to a user-defined class in order to simplify the class definition. @dataclass comes from Python's dataclasses module, available in Python 3.7 and later.
Classes that use the @dataclass decorator are called data classes. Data classes are usually used to store values and have some special methods that are generated automatically.
For example, we can define a data class to represent the coordinates of a point:
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int
In this example, the Point class has two fields, x and y, both of type int.
The @dataclass decorator automatically adds an __init__ method to the Point class. This __init__ takes two arguments (x and y) and uses them to initialize the Point object. Therefore, we can create a Point object like this:
p = Point(1, 2)
In addition, the @dataclass decorator automatically adds some other special methods to the class, including:
- __repr__: returns a string representing the object, e.g. Point(x=1, y=2).
- __eq__: determines whether two objects are equal. By default, two objects are equal if they are of the same type and all their fields are equal.
- __hash__: generated only when the data class is declared immutable with @dataclass(frozen=True) while eq keeps its default of True. The method computes the hash value of the object so that the object can be used as a dictionary key. With the default settings (eq=True, frozen=False), __hash__ is set to None and instances are unhashable.
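The generated methods listed above can be observed directly. This is a small sketch; FrozenPoint is a hypothetical second class added here only to contrast the hashing behaviour.

```python
from dataclasses import dataclass


@dataclass
class Point:
    x: int
    y: int


@dataclass(frozen=True)
class FrozenPoint:  # hypothetical immutable variant, for comparison only
    x: int
    y: int


p = Point(1, 2)
print(repr(p))           # generated __repr__
print(p == Point(1, 2))  # generated __eq__ compares field by field

# A default (mutable) data class has __hash__ set to None.
hashable_default = True
try:
    hash(p)
except TypeError:
    hashable_default = False
print(hashable_default)

# A frozen data class gets a generated __hash__, so it can be a dict key.
lookup = {FrozenPoint(1, 2): "found"}
print(lookup[FrozenPoint(1, 2)])
```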
In the code you provided, the @dataclass decorator is used to define the two data classes DatasetAttr and ModelArguments. This lets you easily create objects of these classes, print representations of the objects, compare objects for equality, and so on.
What are the similarities and differences between __str__ and __repr__?
String representation
We all know that Python's built-in function repr() can express an object in the form of a string, which makes the object easy to identify. This is the "string representation". repr() obtains the string representation through the special method __repr__. If __repr__ is not implemented, then when we print a Vector instance in the console, the string we get may be <Vector object at 0x10a514f98>.
__repr__ is a Python special method that provides the "official" string representation of an object. If defined, it is called by the repr() function. It is also used by print() and str() when the class does not define __str__; when __str__ is defined, print() and str() use that instead. The return value of __repr__ should be a string that represents the current object as unambiguously as possible. This string is useful when you are inspecting objects or debugging code.
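The relationship between the two methods can be illustrated with a short sketch. OnlyRepr and Both are hypothetical classes written just for this comparison.

```python
class OnlyRepr:
    """Defines __repr__ only, so str() falls back to it."""

    def __repr__(self) -> str:
        return "OnlyRepr()"


class Both:
    """Defines both methods, so str() and repr() differ."""

    def __repr__(self) -> str:
        return "Both()"

    def __str__(self) -> str:
        return "a friendly Both"


print(str(OnlyRepr()))  # OnlyRepr()  (no __str__, falls back to __repr__)
print(str(Both()))      # a friendly Both
print(repr(Both()))     # Both()
print([Both()])         # [Both()]  (containers show their items with repr)
```

In short, __repr__ is aimed at developers and is the fallback for everything, while __str__ is an optional, more readable form used by print() and str().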
In your code, the __repr__ method is overridden to return the value of the dataset_name field. This means that when you print a DatasetAttr object or call the repr() function on it, you get the value of the object's dataset_name field.
For example, if you have a DatasetAttr object whose dataset_name field value is "my_dataset", then when you print this object, you will get the string "my_dataset":
dataset_attr = DatasetAttr(load_from="path/to/dataset", dataset_name="my_dataset")
print(dataset_attr)  # Output: my_dataset
It is worth noting that this implementation of __repr__ does not follow the usual convention, because it does not return an unambiguous string. Ideally, the return value of __repr__ can be passed to the eval() function to create a new object equal to the original one. For example, for an object of the Point class, __repr__ might return a string like "Point(x=1, y=2)".
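As an illustration of that convention, here is one possible __repr__ for DatasetAttr that round-trips through eval(). It is a hypothetical alternative, not the original author's code, and __post_init__ is left out for brevity. It also avoids a pitfall of the original: returning self.dataset_name raises a TypeError whenever dataset_name is None, because __repr__ must return a string.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetAttr:
    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __repr__(self) -> str:
        # Hypothetical replacement: spell out the constructor call.
        # !r applies repr() to each value, so strings keep their quotes.
        return (
            f"{type(self).__name__}(load_from={self.load_from!r}, "
            f"dataset_name={self.dataset_name!r}, "
            f"dataset_sha1={self.dataset_sha1!r})"
        )


attr = DatasetAttr(load_from="path/to/dataset", dataset_name="my_dataset")
text = repr(attr)
print(text)

# eval() is acceptable here only because we built the string ourselves;
# never call it on untrusted input.
clone = eval(text, {"DatasetAttr": DatasetAttr})
print(clone == attr)  # True
```

Simply deleting the custom __repr__ would give a similar result, since @dataclass generates a constructor-style representation by default.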