Transformadores de guardar y modelo de carga | ocho

Autor | huggingface compilado | fuente VK | Github

En esta sección se describe cómo guardar y volver a cargar el modelo de ajuste fino (BERT, GPT, GPT-2 y Transformador-XL). Tendrá que guardar tres tipos de archivos para recargar el modelo ha sido bien:

Los nombres de archivo por defecto de estos archivos son los siguientes:

  • Modelo archivo de pesos:pytorch_model.bin
  • perfiles:config.json
  • Glosario file: vocab.txten nombre del BERT y Transformador-XL, vocab.jsonen nombre de GPT / GPT-2 (BPE vocabulario),
  • En nombre del / GPT-2 (BPE vocabulario) archivo de combinación adicional GPT: merges.txt.

Si se utiliza el nombre de archivo predeterminado para guardar el modelo, puede utilizar el método from_pretrained () para recargar el modelo y tokenizer.

Esto es para guardar el modelo, la configuración, y el método recomendado de archivos de configuración. Palabras al output_dirdirectorio, y luego volver a la carga de modelo y tokenizer:

from transformers import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"

# 步骤1:保存一个经过微调的模型、配置和词汇表

#如果我们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
#如果使用预定义的名称保存,则可以使用`from_pretrained`加载
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)

# 步骤2: 重新加载保存的模型

#Bert模型示例
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
#GPT模型示例
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)

Si desea utilizar una ruta específica para cada tipo de archivo, puede utilizar otro método para guardar y volver a cargar el modelo:

output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"

# 步骤1:保存一个经过微调的模型、配置和词汇表

#如果我们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)

# 步骤2: 重新加载保存的模型

# 我们没有使用预定义权重名称、配置名称进行保存,无法使用`from_pretrained`进行加载。
# 下面是在这种情况下的操作方法:

#Bert模型示例
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)

#GPT模型示例
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)

fuente original: https://huggingface.co/transformers/serialization.html#serialization-best-practices

AI acoge favorablemente la atención Pan Chong blog de la estación: http://panchuang.net/

OpenCV documento oficial chino: http://woshicver.com/

Bienvenido atención Pan Chong blog en recursos estación Resumen: http://docs.panchuang.net/

Publicados 372 artículos originales · ganado elogios 1063 · Vistas de 670.000 +

Supongo que te gusta

Origin blog.csdn.net/fendouaini/article/details/105322537
Recomendado
Clasificación