LLama-Efficient-Tuning源码解析 src/llmtuner/hparams/generating_args.py

​​​src/llmtuner/hparams/generating_args.py

from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class GeneratingArguments:
    r"""
    Arguments pertaining to specify the decoding parameters.
    """
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    )
    temperature: Optional[float] = field(
        default=0.95,
        metadata={"help": "The value used to modulate the next token probabilities."}
    )
    top_p: Optional[float] = field(
        default=0.7,
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
    )
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
    )
    repetition_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
    )
    length_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
    )

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)
        return args

generating_args.py

pythonCopy code

from typing import Any, Dict, Optional from dataclasses import asdict, dataclass, field

  • 这两行导入了需要的模块和函数。
    • typing 提供了类型提示相关的模块,如 Any, Dict, 和 Optional
    • dataclasses 提供了数据类的相关功能,如 asdict, dataclass, 和 field
 
 

pythonCopy code

@dataclass

  • 这是一个装饰器,它表示下面的类 GeneratingArguments 是一个数据类。数据类主要用于简化类的创建和管理。
 
 

pythonCopy code

class GeneratingArguments:

  • 定义了一个名为 GeneratingArguments 的类。
 
 

pythonCopy code

r""" Arguments pertaining to specify the decoding parameters. """

  • 这是类的文档字符串(docstring),它描述了这个类的主要功能。r 表示这是一个原始字符串,其中的所有字符都是按照它们的字面含义解释的。

接下来的几行定义了类的属性:

 
 

pythonCopy code

do_sample: Optional[bool] = field( ... ) ...

  • do_sampletemperaturetop_p 等是 GeneratingArguments 类的属性。
  • Optional[type] 表示该属性可以是指定的 type 或者 None
  • field(...) 用于为属性提供额外的配置,比如默认值和元数据。
 
 

pythonCopy code

def to_dict(self) -> Dict[str, Any]:

  • 这是一个方法,它的功能是将 GeneratingArguments 对象的属性转换为字典。
 
 

pythonCopy code

args = asdict(self)

  • 使用 asdict 函数将当前对象(self)转换为一个字典,并存储在 args 变量中。
 
 

pythonCopy code

if args.get("max_new_tokens", None):

  • 检查 args 字典中是否有名为 "max_new_tokens" 的键,如果有并且其值不为 None,则进入下一行。
 
 

pythonCopy code

args.pop("max_length", None)

  • args 字典中删除名为 "max_length" 的键及其对应的值。如果 "max_length" 键不存在,该操作仍然会正常执行,因为为 pop 方法提供了默认值 None
 
 

pythonCopy code

return args

  • 返回处理后的 args 字典。

整体来说,这段代码定义了一个名为 GeneratingArguments 的数据类,该类主要用于配置解码参数。这些参数可能与自然语言处理任务中的文本生成有关,如用于指定解码策略(贪婪解码、采样、束搜索等)。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/132793661