Notes on the pitfalls hit while reproducing the BART model

Problem 1: Dataset inputs and labels are reversed

Cause

The inputs and labels are not specified when the dataset is loaded.

    # Get the column names for input/target.
    # Logic for choosing input/target:
    # 1. explicitly named columns (data_args.text_column, data_args.summary_column)
    # 2. a known dataset (its built-in name mapping)
    # 3. default to (dataset_columns[0], dataset_columns[1])
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

The logic for choosing input/target:
1. Explicitly specified names (data_args.text_column, data_args.summary_column)
2. A known dataset name (looked up in the built-in name map)
3. Otherwise the default (dataset_columns[0], dataset_columns[1])

Solution

Specify the columns when the arguments are read. The script supports reading its arguments from a file, which requires a parameter file such as config.json.

The parameter file is loaded at run_summarization.py:311.

    # main, run_summarization.py:311
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
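For reference, here is a minimal stdlib-only sketch of what parse_json_file does. DataArgs is a made-up stand-in for the script's DataTrainingArguments, and this parse_json_file is a simplified illustration, not the real HfArgumentParser method:

```python
import json
from dataclasses import dataclass, fields

# Made-up stand-in for the script's DataTrainingArguments dataclass.
@dataclass
class DataArgs:
    text_column: str = None
    summary_column: str = None
    dataset_name: str = None

def parse_json_file(path, cls):
    """Simplified sketch: read the JSON file and keep only the
    keys that the target dataclass declares."""
    with open(path) as f:
        params = json.load(f)
    keys = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in params.items() if k in keys})
```

Keys the dataclass does not declare (e.g. the training arguments) are simply skipped here; the real parser routes them to the other argument dataclasses instead.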

Parameter file content

{
  "model_name_or_path": "fnlp_bart-base-chinese",

  "text_column": "context",
  "summary_column": "response",
  "max_source_length": 128,
  "max_target_length": 256,

  "dataset_name": "data_douhao",
  "num_train_epochs": 160,
  "save_steps": 1000,
  "per_device_train_batch_size": 32,
  "per_device_eval_batch_size": 32,
  "do_train": true,
  "do_eval": true,
  "do_predict": false,
  "include_inputs_for_metrics": true,
  "predict_with_generate": true,
  "output_dir": "checkpoints/model_douhao2",
  "overwrite_output_dir": true
}

In the parameter file, "text_column": "context" and "summary_column": "response" specify the inputs and labels.

Other solutions

The logic for choosing input/target:
1. Specify the names (data_args.text_column, data_args.summary_column) by assigning them directly after the dataset is loaded:

  data_args.text_column = "context"
  data_args.summary_column = "response"

2. Use the dataset's name map: run_summarization.py:289 holds a dataset-to-column mapping table, so just add one dataset: (input, output) entry.

summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    ......
    "wiki_summary": ("article", "highlights"),
    "multi_news": ("document", "summary"),

    # just add one line here
    # "dataset name": (input, output)
    "data": ("context", "response"),
}

3. Rely on the default (dataset_columns[0], dataset_columns[1]).
(This is the fallback that bit us.)
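To see how option 3 causes the reversal, here is a simplified reconstruction of the column-selection logic; pick_columns is a hypothetical helper wrapping the script's code, not part of run_summarization.py:

```python
# Trimmed-down copy of the script's built-in name map.
summarization_name_mapping = {
    "multi_news": ("document", "summary"),
}

def pick_columns(dataset_name, column_names, text_column=None, summary_column=None):
    """Hypothetical helper reproducing the script's fallback logic."""
    dataset_columns = summarization_name_mapping.get(dataset_name, None)
    if text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    if summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    return text_column, summary_column

# Unknown dataset whose columns happen to be stored response-first:
# the silent fallback returns ("response", "context") -- inputs and labels swapped.
swapped = pick_columns("data_douhao", ["response", "context"])
# Explicit names fix it: returns ("context", "response").
fixed = pick_columns("data_douhao", ["response", "context"],
                     text_column="context", summary_column="response")
```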

Problem 2: The generation length is always 20

The trained model always generates sequences of at most 20 tokens.

Cause

When the transformers library loads model hyperparameters, a default value max_length = 20 controls the length of the generated text. If the model's config file does not set this value, the default is loaded automatically. (The default is a bit short.)

# Stepping into the calls one layer at a time, in order:
# --------------------------------
main, run_summarization.py:417  # this line loads the BART model config
config = AutoConfig.from_pretrained(
# --------------------------------
from_pretrained, configuration_auto.py:941  # with the config file values loaded, look up the matching model config class
# the model's config file declares "model_type": "bart"
# so on loading, config_dict["model_type"] = "bart"
return config_class.from_dict(config_dict, **unused_kwargs)
# --------------------------------
from_dict, configuration_utils.py:701   # still resolving the matching model config class
config = cls(**config_dict)
# --------------------------------
__init__, configuration_bart.py:165 # the matching BART config class is found and initialized
super().__init__(
# --------------------------------
__init__, configuration_utils.py:285    # the generic config loading path
self.max_length = kwargs.pop("max_length", 20)

Put a breakpoint on line 285 of the last file, configuration_utils.py, and you can watch the default value being loaded.
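The mechanism is just dict.pop with a default. A toy config class (a stand-in for the generic config __init__, not the real implementation) makes the behavior obvious:

```python
class ToyConfig:
    """Toy stand-in for the generic config __init__ in configuration_utils.py."""
    def __init__(self, **kwargs):
        # If the config file provided no "max_length", pop falls back to 20.
        self.max_length = kwargs.pop("max_length", 20)

assert ToyConfig().max_length == 20                 # nothing set in the config file
assert ToyConfig(max_length=256).max_length == 256  # explicitly set
```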

Why is the default value only 20? Really?

Solution

This is easy to fix: after loading the config, set config.max_length to 256 and the generated text length changes accordingly.

config.max_length = 256

Origin blog.csdn.net/aiaidexiaji/article/details/131063936