Problem 1: Data set inputs and labels are reversed
Problem cause
The inputs and labels are not specified when the dataset is loaded, so the script falls back to its defaults.
# Get the column names for input/target.
# Logic for setting input/target:
#   1. Explicitly named columns (data_args.text_column, data_args.summary_column)
#   2. The dataset's entry in the built-in name map
#   3. Default: (dataset_columns[0], dataset_columns[1])
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None:
    text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
    text_column = data_args.text_column
    if text_column not in column_names:
        raise ValueError(
            f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
        )
if data_args.summary_column is None:
    summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
    summary_column = data_args.summary_column
    if summary_column not in column_names:
        raise ValueError(
            f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
        )
The logic for setting input/target:
1. Explicitly specify the names (data_args.text_column, data_args.summary_column)
2. Use the dataset's entry in the built-in name map
3. Default to (dataset_columns[0], dataset_columns[1])
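The three-way fallback above can be sketched in a few self-contained lines (the function and variable names here are illustrative, not the script's actual helpers):

```python
# Minimal sketch of the column-selection fallback in run_summarization.py.
summarization_name_mapping = {
    "multi_news": ("document", "summary"),
}

def pick_columns(dataset_name, column_names, text_column=None, summary_column=None):
    """Return (input, target) columns using the same priority:
    explicit args > per-dataset mapping > first two dataset columns."""
    dataset_columns = summarization_name_mapping.get(dataset_name)
    text = text_column or (dataset_columns[0] if dataset_columns else column_names[0])
    summary = summary_column or (dataset_columns[1] if dataset_columns else column_names[1])
    for col in (text, summary):
        if col not in column_names:
            raise ValueError(f"column '{col}' needs to be one of: {', '.join(column_names)}")
    return text, summary

# An unmapped dataset whose columns happen to be ordered (label, input)
# silently falls through to case 3 -- inputs and labels get swapped:
print(pick_columns("my_data", ["response", "context"]))
# Explicit names fix it:
print(pick_columns("my_data", ["response", "context"],
                   text_column="context", summary_column="response"))
```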
Solution
Here we specify the columns when passing parameters. The script already supports reading parameters from a file, which requires a parameter file config.json.
Line 311 of run_summarization.py loads the parameter file:
# main, run_summarization.py:311
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
Parameter file content
{
"model_name_or_path":"fnlp_bart-base-chinese",
"text_column":"context",
"summary_column":"response",
"max_source_length":128,
"max_target_length":256,
"dataset_name":"data_douhao",
"num_train_epochs":160,
"save_steps":1000,
"per_device_train_batch_size":32,
"per_device_eval_batch_size":32,
"do_train":true,
"do_eval":true,
"do_predict":false,
"include_inputs_for_metrics":true,
"predict_with_generate":true,
"output_dir":"checkpoints/model_douhao2",
"overwrite_output_dir":true
}
In the parameter file, "text_column": "context" and "summary_column": "response" are what specify the inputs and labels.
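For reference, the mechanism behind parse_json_file can be sketched with plain dataclasses (a simplified stand-in, not the real HfArgumentParser):

```python
import json
from dataclasses import dataclass, fields

@dataclass
class DataArguments:
    # Only the two fields relevant to this problem; the real script has many more.
    text_column: str = None
    summary_column: str = None

def parse_json_config(path, cls):
    """Load a JSON file and keep only the keys matching the dataclass fields."""
    with open(path) as f:
        raw = json.load(f)
    known = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in raw.items() if k in known})

# Write a small config and parse it back; unknown keys are ignored here.
with open("config.json", "w") as f:
    json.dump({"text_column": "context", "summary_column": "response",
               "max_source_length": 128}, f)

data_args = parse_json_config("config.json", DataArguments)
print(data_args.text_column, data_args.summary_column)  # context response
```

The actual script is then launched with the JSON file as its single argument: `python run_summarization.py config.json`.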
Other solutions
Following the same input/target logic:
1. Explicitly specify the names (data_args.text_column, data_args.summary_column) by assigning them directly after loading the dataset:
data_args.text_column = "context"
data_args.summary_column = "response"
2. Use the dataset's built-in name map: run_summarization.py:289 contains a dataset-to-columns mapping table, so just add one entry mapping the dataset name to its (input, output) tuple.
summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    ......
    "wiki_summary": ("article", "highlights"),
    "multi_news": ("document", "summary"),
    # Just add one line:
    # "dataset name": (input, output)
    "data": ("context", "response"),
}
3. Rely on the default (dataset_columns[0], dataset_columns[1]) — this is the case that tripped us up.
Problem 2: The generation length is always 20
Text generated by the trained model is always capped at a length of 20.
Problem cause
When the transformers library loads model hyperparameters, there is a default value max_length = 20 that controls the length of the generated text. If the model's config file does not set this value, the default is loaded automatically. (The default is a bit short.)
# Step into the calls one level at a time, in order:
# --------------------------------
main, run_summarization.py:417  # this line loads the BART model config
config = AutoConfig.from_pretrained(
# --------------------------------
from_pretrained, configuration_auto.py:941  # after reading the config file values, look up the matching model config class
# the model's config file states "model_type": "bart",
# so on loading, config_dict["model_type"] = "bart"
return config_class.from_dict(config_dict, **unused_kwargs)
# --------------------------------
from_dict, configuration_utils.py:701  # still resolving the matching model config class
config = cls(**config_dict)
# --------------------------------
__init__, configuration_bart.py:165  # the matching BART config class is found and initialized
super().__init__(
# --------------------------------
__init__, configuration_utils.py:285  # the generic config initializer runs
self.max_length = kwargs.pop("max_length", 20)
Put a breakpoint on line 285 of the last file, configuration_utils.py, and you can watch the default value being loaded.
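The pop-with-default pattern on that line is easy to reproduce in isolation (TinyConfig below is a made-up stand-in for PretrainedConfig):

```python
class TinyConfig:
    def __init__(self, **kwargs):
        # If config.json carries no "max_length" key, pop() falls back to 20.
        self.max_length = kwargs.pop("max_length", 20)

print(TinyConfig().max_length)                # 20  (default kicks in)
print(TinyConfig(max_length=256).max_length)  # 256 (explicit value wins)
```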
Solution
This is actually easy to fix: after loading the config, set config.max_length to 256 and the generated text length changes accordingly.
config.max_length = 256