ソースコード分析 ChatGLM Efficient Tuning/src/utils/config.py

import os
import json
import torch
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: Optional[str] = field(
        default="THUDM/chatglm-6b",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    use_v2: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use ChatGLM2 or not."}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name."}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if self.use_v2 and self.model_name_or_path == "THUDM/chatglm-6b":
            self.model_name_or_path = "THUDM/chatglm2-6b"

        if self.checkpoint_dir is not None: # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

DatasetAttrこのコードは、と の2 つのデータ クラスを定義しますModelArguments

  1. from typing import Any, Dict, List, Literal, Optional: (任意の型)、(辞書型)、(リスト型)、(リテラル型、可能な値を指定可能)、および(オプション、値は None の場合もあります)を含む、いくつかの型注釈ツールをtypingモジュールからインポートします。AnyDictListLiteralOptional

  2. from dataclasses import asdict, dataclass, fielddataclasses: (データ クラスasdictオブジェクトを辞書に変換する)、(データdataclassクラスを定義するデコレータ)、(データ クラスのフィールドを定義する) など、いくつかのデータ クラス関連ツールをモジュールからインポートしますfield

  3. クラスを定義しますDatasetAttr

    • @dataclassDatasetAttrはデータクラスであることを示すデコレータです。
    • このクラスには 3 つのフィールドがあります。 、load_fromdataset_namedataset_sha1、それぞれ、データセットのロード先、データセットの名前、およびデータセットの sha1 値を示します。最後の 2 つのフィールドはオプションであり、None にすることもできます。
    • __repr__このクラスのインスタンスを出力するときにメソッドをオーバーライドすると、dataset_nameフィールドの値が返されます。
    • このメソッドでは、 および__post_init__の 4 つのフィールドが定義されています。これらのフィールドは、オブジェクトのインスタンス化後に自動的に割り当てられます。prompt_columnquery_columnresponse_columnhistory_column
  4. クラスを定義しますModelArguments

    • このクラスには多くのフィールドがあり、各フィールドはfieldフィールドのデフォルト値を定義できる関数と、フィールドの意味を理解するのに役立つメタデータを使用して定義されています。
    • このクラスには、モデル名/パス、構成名、トークナイザー名、キャッシュ ディレクトリ、モデル バージョン、高速トークナイザーを使用するかどうか、検証トークンを使用するかどうか、量子化ビット、量子化タイプ、二重量子化を使用するかどうか、チェックポイント ディレクトリ、報酬モデル パス、トレーニングを継続するかどうか、損失マップを描画するかどうかなど。
    • __post_init__メソッドでは、use_v2model_name_or_pathの値に応じて変更される場合がありますmodel_name_or_pathまた、None でない場合は、に従って分割checkpoint_dirされますまた、が None でない場合、の値は 4 または 8 でなければならないと主張します。,checkpoint_dirquantization_bitquantization_bit

ModelArguments次に、クラス内の各フィールドについて詳しく説明します。

  • model_name_or_path: 事前トレーニングされたモデルの名前またはパスを示します。デフォルトは「THUDM/chatglm-6b」で、これは事前トレーニングされたモデルの識別子である可能性があります。__post_init__メソッド内で、use_v2が true でmodel_name_or_path「THUDM/chatglm-6b」の場合、model_name_or_path「THUDM/chatglm2-6b」に変更されます。

  • use_v2:ChatGLM2を使用するかどうかを示します。デフォルトは false です。

  • config_name: 事前トレーニング構成の名前またはパス。と異なる場合、model_name_or_pathこのフィールドは必須です。

  • tokenizer_name: 事前トレーニングされたトークナイザーの名前またはパス。と異なる場合、model_name_or_pathこのフィールドは必須です。

  • cache_dir: ハグフェイス.co からダウンロードした事前トレーニング済みモデルを保存する場所。

  • use_fast_tokenizer: クイックトークナイザーを使用するかどうかを示します。デフォルトは true です。

  • model_revision: 使用する特定のモデル バージョン (ブランチ名、タグ名、またはコミット ID を指定できます)。デフォルトは「メイン」です。

  • use_auth_tokenhuggingface-cli login: 実行時に生成されたトークンを使用するかどうかデフォルトは false です。

  • quantization_bit: モデルのビット数を量子化します。__post_init__このメソッドでは、 が None でない場合、の値は 4 または 8 でなければならないことquantization_bitがアサートされます。quantization_bit

  • quantization_type: int4 トレーニング用の量子化データ型。デフォルトは「nf4」です。

  • double_quantization: int4 トレーニングで二重量子化を使用するかどうか。デフォルトは true です。

  • compute_dtype: 量子化設定に使用されるデータ型。通常、このパラメータを手動で指定する必要はありません。

  • checkpoint_dir: モデルのチェックポイントと構成を含むディレクトリへのパス。__post_init__メソッド内でNone でない場合はに従って分割checkpoint_dirされます,checkpoint_dir

  • reward_model: 報酬モデルのチェックポイントを含むディレクトリへのパス。

  • resume_lora_training: 最後の LoRA ウェイトからトレーニングを継続するか、ウェイトを結合した後に新しいウェイトを作成するかを示します。デフォルトは true です。

  • plot_loss: 微調整後にトレーニング損失をプロットするかどうかを示します。デフォルトは false です。

これらのフィールドのほとんどにはデフォルト値がありますが、オブジェクトの作成時にModelArgumentsカスタム値を指定することもできます。どのモデルを使用するか、モデルを量子化する方法、トレーニング損失をプロットするかどうかなど、モデルの微調整プロセスを制御できます。

@dataclass詳しく説明してください

Python では、クラスの定義簡素化するために、ユーザー定義クラスに@dataclass特別なメソッド ( など) を自動的に追加するために使用されるデコレーターです。Python のモジュール。Python 3.7 以降で利用可能。__init____repr__@dataclassdataclasses

デコレータを使用するクラスは@dataclassデータ クラスと呼ばれます。データ クラスは通常、何らかの値を格納するために使用され、自動的に生成されるいくつかの特別なメソッドを持っています。

たとえば、点の座標を表すデータ クラスを定義できます。

 

pythonコードをコピー

from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

from dataclasses import dataclass @dataclass class Point: x: int y: int

この例では、Pointクラスにはxと という2 つのフィールドがありy、これらのフィールドの型は ですint

@dataclassデコレーターは、2 つの引数 () を受け取り、それらを使用してオブジェクトを初期化するメソッドをクラス__init__に自動的に追加します。したがって、次のようなオブジェクトを作成できます。Point__init__xyPointPoint

 

pythonコードをコピー

p = Point(1, 2)

さらに、@dataclassデコレータは、次のような他の特別なメソッドをクラスに自動的に追加します。

  • __repr__:オブジェクトを表す文字列Point(x=1, y=2)を返します
  • __eq__: 2 つのオブジェクトが等しいかどうかを判断します。このメソッドは、オブジェクトが同じタイプであり、すべてのフィールドが等しい場合、デフォルトでオブジェクトが等しいと想定します。
  • __hash__: データ クラスが不変の場合 (つまり、すべてのフィールドが不変の場合)、メソッドが__hash__自動的に追加されます。このメソッドは、オブジェクトを辞書キーとして使用できるように、オブジェクトのハッシュ値を計算するために使用されます。

提供したコードでは、@dataclassデコレータを使用してこれら 2 つのデータ クラスを定義しDatasetAttrますModelArgumentsこれにより、これらのクラスのオブジェクトを簡単に作成したり、オブジェクトの表現を簡単に出力したり、オブジェクトが等しいかどうかを比較したりすることができます。

__str__そして__repr__類似点と相違点は何ですか?

文字列の表現

Python の組み込み関数はオブジェクトを文字列の形式で表現できるrepr()ため、識別するのに便利であることは誰もが知っています。これが「文字列表現」です。この特別なメソッドを通じてオブジェクトの文字列表現。これが実装されていない場合、コンソールでベクター インスタンスを出力すると、取得する文字列は<0x10a514f98 のベクター オブジェクト> になる可能性があります。repr()__repr____repr__

__repr__は、オブジェクトの「公式」文字列表現として知られる Python の特別なメソッドです定義されている場合、このメソッドはrepr()オブジェクトを印刷するとき、または関数を使用するときに呼び出されます。このメソッドの戻り値は、現在のオブジェクトを表す、可能な限り明確な文字列である必要があります。このメソッドによって返される文字列は、オブジェクトを表示するときやコードをデバッグするときに役立ちます。

コードでは、フィールドの値__repr__を返すようにメソッドがオーバーライドされます。dataset_nameこれは、オブジェクトを印刷するDatasetAttrrepr()関数を使用すると、オブジェクトのdataset_nameフィールドの値を取得できることを意味します。

DatasetAttrたとえば、フィールド値が「my_dataset」であるオブジェクトがある場合、このオブジェクトを印刷すると、文字列「my_dataset」dataset_nameが取得されます

 
 

pythonコードをコピー

dataset_attr = DatasetAttr(load_from="path/to/dataset", dataset_name="my_dataset") print(dataset_attr) # 输出:my_dataset

__repr__このメソッドの実装は、明確な文字列を返さないため、通常の規約に準拠していないことに注意してください一般に、__repr__メソッドの戻り値は、eval()関数を使用して元のオブジェクトと同一の新しいオブジェクトを作成できる必要があります。たとえば、Pointclass のオブジェクトの場合、そのメソッドは「Point(x=1, y=2)」のような文字列を返すことがあります。 __repr__

おすすめ

転載: blog.csdn.net/sinat_37574187/article/details/131487800