src/llmtuner/hparams/generating_args.py
from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
temperature: Optional[float] = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."}
)
top_p: Optional[float] = field(
default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
)
top_k: Optional[int] = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
)
num_beams: Optional[int] = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_length: Optional[int] = field(
default=None,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
)
max_new_tokens: Optional[int] = field(
default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
)
length_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
)
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", None):
args.pop("max_length", None)
return args
generating_args.py
pythonCopy code
from typing import Any, Dict, Optional from dataclasses import asdict, dataclass, field
- 这两行导入了需要的模块和函数。
typing
提供了类型提示相关的模块,如Any
,Dict
, 和Optional
。dataclasses
提供了数据类的相关功能,如asdict
,dataclass
, 和field
。
pythonCopy code
@dataclass
- 这是一个装饰器,它表示下面的类
GeneratingArguments
是一个数据类。数据类主要用于简化类的创建和管理。
pythonCopy code
class GeneratingArguments:
- 定义了一个名为
GeneratingArguments
的类。
pythonCopy code
r""" Arguments pertaining to specify the decoding parameters. """
- 这是类的文档字符串(docstring),它描述了这个类的主要功能。
r
表示这是一个原始字符串,其中的所有字符都是按照它们的字面含义解释的。
接下来的几行定义了类的属性:
pythonCopy code
do_sample: Optional[bool] = field( ... ) ...
do_sample
、temperature
、top_p
等是GeneratingArguments
类的属性。Optional[type]
表示该属性可以是指定的type
或者None
。field(...)
用于为属性提供额外的配置,比如默认值和元数据。
pythonCopy code
def to_dict(self) -> Dict[str, Any]:
- 这是一个方法,它的功能是将
GeneratingArguments
对象的属性转换为字典。
pythonCopy code
args = asdict(self)
- 使用
asdict
函数将当前对象(self
)转换为一个字典,并存储在args
变量中。
pythonCopy code
if args.get("max_new_tokens", None):
- 检查
args
字典中是否有名为 "max_new_tokens" 的键,如果有并且其值不为None
,则进入下一行。
pythonCopy code
args.pop("max_length", None)
- 从
args
字典中删除名为 "max_length" 的键及其对应的值。如果 "max_length" 键不存在,该操作仍然会正常执行,因为为pop
方法提供了默认值None
。
pythonCopy code
return args
- 返回处理后的
args
字典。
整体来说,这段代码定义了一个名为 GeneratingArguments
的数据类,该类主要用于配置解码参数。这些参数可能与自然语言处理任务中的文本生成有关,如用于指定解码策略(贪婪解码、采样、束搜索等)。