import os
import json
import torch
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: Optional[str] = field(
default="THUDM/chatglm-6b",
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
)
use_v2: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use ChatGLM2 or not."}
)
config_name: Optional[str] = field(
default=None,
metadata={"help": "Pretrained config name or path if not the same as model_name."}
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
)
use_fast_tokenizer: Optional[bool] = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
)
model_revision: Optional[str] = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
)
use_auth_token: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4",
metadata={"help": "Quantization data type to use in int4 training."}
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
self.model_name_or_path = "THUDM/chatglm2-6b"
if self.checkpoint_dir is not None: # support merging lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
DatasetAttr
このコードは、と の2 つのデータ クラスを定義しますModelArguments
。
-
from typing import Any, Dict, List, Literal, Optional
: (任意の型)、(辞書型)、(リスト型)、(リテラル型、可能な値を指定可能)、および(オプション、値は None の場合もあります)を含む、いくつかの型注釈ツールをtyping
モジュールからインポートします。Any
Dict
List
Literal
Optional
-
from dataclasses import asdict, dataclass, field
dataclasses
: (データ クラスasdict
オブジェクトを辞書に変換する)、(データdataclass
クラスを定義するデコレータ)、(データ クラスのフィールドを定義する) など、いくつかのデータ クラス関連ツールをモジュールからインポートしますfield
。 -
クラスを定義します
DatasetAttr
。@dataclass
DatasetAttr
はデータクラスであることを示すデコレータです。- このクラスには 3 つのフィールドがあります。 、
load_from
、dataset_name
はdataset_sha1
、それぞれ、データセットのロード先、データセットの名前、およびデータセットの sha1 値を示します。最後の 2 つのフィールドはオプションであり、None にすることもできます。 __repr__
このクラスのインスタンスを出力するときにメソッドをオーバーライドすると、dataset_name
フィールドの値が返されます。- このメソッドでは、 、、および
__post_init__
の 4 つのフィールドが定義されています。これらのフィールドは、オブジェクトのインスタンス化後に自動的に割り当てられます。prompt_column
query_column
response_column
history_column
-
クラスを定義します
ModelArguments
。- このクラスには多くのフィールドがあり、各フィールドは
field
フィールドのデフォルト値を定義できる関数と、フィールドの意味を理解するのに役立つメタデータを使用して定義されています。 - このクラスには、モデル名/パス、構成名、トークナイザー名、キャッシュ ディレクトリ、モデル バージョン、高速トークナイザーを使用するかどうか、検証トークンを使用するかどうか、量子化ビット、量子化タイプ、二重量子化を使用するかどうか、チェックポイント ディレクトリ、報酬モデル パス、トレーニングを継続するかどうか、損失マップを描画するかどうかなど。
__post_init__
メソッドでは、use_v2
とmodel_name_or_path
の値に応じて変更される場合がありますmodel_name_or_path
。また、None でない場合は、に従って分割checkpoint_dir
されます。また、が None でない場合、の値は 4 または 8 でなければならないと主張します。,
checkpoint_dir
quantization_bit
quantization_bit
- このクラスには多くのフィールドがあり、各フィールドは
ModelArguments
次に、クラス内の各フィールドについて詳しく説明します。
-
model_name_or_path
: 事前トレーニングされたモデルの名前またはパスを示します。デフォルトは「THUDM/chatglm-6b」で、これは事前トレーニングされたモデルの識別子である可能性があります。__post_init__
メソッド内で、use_v2
が true でmodel_name_or_path
「THUDM/chatglm-6b」の場合、model_name_or_path
「THUDM/chatglm2-6b」に変更されます。 -
use_v2
:ChatGLM2を使用するかどうかを示します。デフォルトは false です。 -
config_name
: 事前トレーニング構成の名前またはパス。と異なる場合、model_name_or_path
このフィールドは必須です。 -
tokenizer_name
: 事前トレーニングされたトークナイザーの名前またはパス。と異なる場合、model_name_or_path
このフィールドは必須です。 -
cache_dir
: ハグフェイス.co からダウンロードした事前トレーニング済みモデルを保存する場所。 -
use_fast_tokenizer
: クイックトークナイザーを使用するかどうかを示します。デフォルトは true です。 -
model_revision
: 使用する特定のモデル バージョン (ブランチ名、タグ名、またはコミット ID を指定できます)。デフォルトは「メイン」です。 -
use_auth_token
huggingface-cli login
: 実行時に生成されたトークンを使用するかどうか。デフォルトは false です。 -
quantization_bit
: モデルのビット数を量子化します。__post_init__
このメソッドでは、 が None でない場合、の値は 4 または 8 でなければならないことquantization_bit
がアサートされます。quantization_bit
-
quantization_type
: int4 トレーニング用の量子化データ型。デフォルトは「nf4」です。 -
double_quantization
: int4 トレーニングで二重量子化を使用するかどうか。デフォルトは true です。 -
compute_dtype
: 量子化設定に使用されるデータ型。通常、このパラメータを手動で指定する必要はありません。 -
checkpoint_dir
: モデルのチェックポイントと構成を含むディレクトリへのパス。__post_init__
メソッド内でNone でない場合はに従って分割checkpoint_dir
されます。,
checkpoint_dir
-
reward_model
: 報酬モデルのチェックポイントを含むディレクトリへのパス。 -
resume_lora_training
: 最後の LoRA ウェイトからトレーニングを継続するか、ウェイトを結合した後に新しいウェイトを作成するかを示します。デフォルトは true です。 -
plot_loss
: 微調整後にトレーニング損失をプロットするかどうかを示します。デフォルトは false です。
これらのフィールドのほとんどにはデフォルト値がありますが、オブジェクトの作成時にModelArguments
カスタム値を指定することもできます。どのモデルを使用するか、モデルを量子化する方法、トレーニング損失をプロットするかどうかなど、モデルの微調整プロセスを制御できます。
@dataclass詳しく説明してください
Python では、クラスの定義を簡素化するために、ユーザー定義クラスに@dataclass
特別なメソッド ( など) を自動的に追加するために使用されるデコレーターです。Python のモジュール。Python 3.7 以降で利用可能。__init__
__repr__
@dataclass
dataclasses
デコレータを使用するクラスは@dataclass
データ クラスと呼ばれます。データ クラスは通常、何らかの値を格納するために使用され、自動的に生成されるいくつかの特別なメソッドを持っています。
たとえば、点の座標を表すデータ クラスを定義できます。
pythonコードをコピー
from dataclasses import dataclass
@dataclass
class Point:
x: int
y: int
from dataclasses import dataclass @dataclass class Point: x: int y: int
この例では、Point
クラスにはx
と という2 つのフィールドがありy
、これらのフィールドの型は ですint
。
@dataclass
デコレーターは、2 つの引数 (と) を受け取り、それらを使用してオブジェクトを初期化するメソッドをクラス__init__
に自動的に追加します。したがって、次のようなオブジェクトを作成できます。Point
__init__
x
y
Point
Point
pythonコードをコピー
p = Point(1, 2)
さらに、@dataclass
デコレータは、次のような他の特別なメソッドをクラスに自動的に追加します。
__repr__
:オブジェクトを表す文字列Point(x=1, y=2)
を返します。__eq__
: 2 つのオブジェクトが等しいかどうかを判断します。このメソッドは、オブジェクトが同じタイプであり、すべてのフィールドが等しい場合、デフォルトでオブジェクトが等しいと想定します。__hash__
: データ クラスが不変の場合 (つまり、すべてのフィールドが不変の場合)、メソッドが__hash__
自動的に追加されます。このメソッドは、オブジェクトを辞書キーとして使用できるように、オブジェクトのハッシュ値を計算するために使用されます。
提供したコードでは、@dataclass
デコレータを使用してこれら 2 つのデータ クラスを定義しDatasetAttr
ますModelArguments
。これにより、これらのクラスのオブジェクトを簡単に作成したり、オブジェクトの表現を簡単に出力したり、オブジェクトが等しいかどうかを比較したりすることができます。
__str__
そして__repr__
類似点と相違点は何ですか?
文字列の表現
Python の組み込み関数はオブジェクトを文字列の形式で表現できるrepr()
ため、識別するのに便利であることは誰もが知っています。これが「文字列表現」です。この特別なメソッドを通じて、オブジェクトの文字列表現。これが実装されていない場合、コンソールでベクター インスタンスを出力すると、取得する文字列は<0x10a514f98 のベクター オブジェクト> になる可能性があります。repr()
__repr__
__repr__
__repr__
は、オブジェクトの「公式」文字列表現として知られる Python の特別なメソッドです。定義されている場合、このメソッドはrepr()
オブジェクトを印刷するとき、または関数を使用するときに呼び出されます。このメソッドの戻り値は、現在のオブジェクトを表す、可能な限り明確な文字列である必要があります。このメソッドによって返される文字列は、オブジェクトを表示するときやコードをデバッグするときに役立ちます。
コードでは、フィールドの値__repr__
を返すようにメソッドがオーバーライドされます。dataset_name
これは、オブジェクトを印刷するDatasetAttr
かrepr()
関数を使用すると、オブジェクトのdataset_name
フィールドの値を取得できることを意味します。
DatasetAttr
たとえば、フィールド値が「my_dataset」であるオブジェクトがある場合、このオブジェクトを印刷すると、文字列「my_dataset」dataset_name
が取得されます。
pythonコードをコピー
dataset_attr = DatasetAttr(load_from="path/to/dataset", dataset_name="my_dataset") print(dataset_attr) # 输出:my_dataset
__repr__
このメソッドの実装は、明確な文字列を返さないため、通常の規約に準拠していないことに注意してください。一般に、__repr__
メソッドの戻り値は、eval()
関数を使用して元のオブジェクトと同一の新しいオブジェクトを作成できる必要があります。たとえば、Point
class のオブジェクトの場合、そのメソッドは「Point(x=1, y=2)」のような文字列を返すことがあります。 __repr__