Several ways to download models

Problem Description

As an NLP algorithm engineer, I use Hugging Face's open-source transformers package almost every day, and every new model has to be downloaded before it can be used. If the training server has Internet access, you can download a model directly by calling the from_pretrained method. In my experience, however, this approach, convenient as it is, has two problems:

  • On a poor network connection, downloading takes a long time; even a small model can take several hours.
  • If you switch to a different training server, you have to download the model all over again.

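You may wonder why a model that has already been downloaded cannot simply be copied to the new server. Look at the files that from_pretrained saves (usually in the ~/.cache/huggingface/transformers folder): in the transformers versions this article is based on, they are stored under long hashed file names rather than names such as config.json, so it is hard to tell which file belongs to which model.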


Downloading with transformers (recommended)

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

model_name = "openai/whisper-large-v2"
# from_pretrained downloads the files on first use and caches them locally
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
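One way around the migration problem described at the start is to save the downloaded model to an ordinary folder with save_pretrained, copy that folder to the other server, and load it from the local path. The sketch below assumes that workflow; local_dir_for, export_model and load_exported are hypothetical helper names, and transformers is imported inside the functions only so the snippet can be loaded without it installed.

```python
def local_dir_for(model_name: str) -> str:
    """Map a repo id such as 'openai/whisper-large-v2' to a flat folder name."""
    return model_name.replace("/", "-")


def export_model(model_name: str) -> str:
    """Download once, then save processor and model under readable file names."""
    # Lazy import: transformers is only needed when the function is called
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

    target = local_dir_for(model_name)
    AutoProcessor.from_pretrained(model_name).save_pretrained(target)
    AutoModelForSpeechSeq2Seq.from_pretrained(model_name).save_pretrained(target)
    return target


def load_exported(target: str):
    """Load the copied folder on another machine without contacting the Hub."""
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

    processor = AutoProcessor.from_pretrained(target, local_files_only=True)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(target, local_files_only=True)
    return processor, model
```

For example, export_model("openai/whisper-large-v2") writes an openai-whisper-large-v2 folder that can be copied as-is and then passed to load_exported on the new server.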

Downloading with Hugging Face Hub (recommended)

pip install huggingface_hub

from huggingface_hub import snapshot_download

snapshot_download(repo_id="bert-base-chinese")
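# allow_patterns and ignore_patterns (called allow_regex and ignore_regex in older huggingface_hub versions): the former downloads only the matching files, the latter skips the matching files and downloads the rest
snapshot_download(repo_id="bert-base-chinese", ignore_patterns=["*.h5", "*.ot", "*.msgpack"])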
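The values passed to these parameters are shell-style glob patterns such as *.h5, not regular expressions. To preview which files a given allow/ignore combination would keep before starting a large download, the filtering can be mimicked with Python's fnmatch module. This is a sketch, not huggingface_hub's own implementation; preview_filter is a hypothetical helper and the file list is illustrative only.

```python
from fnmatch import fnmatch
from typing import Iterable, List, Optional


def preview_filter(
    filenames: Iterable[str],
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> List[str]:
    """Return the files that would be kept by the given glob patterns."""
    kept = []
    for name in filenames:
        # When allow_patterns is set, a file must match at least one of them
        if allow_patterns is not None and not any(fnmatch(name, p) for p in allow_patterns):
            continue
        # A file matching any ignore pattern is skipped
        if ignore_patterns is not None and any(fnmatch(name, p) for p in ignore_patterns):
            continue
        kept.append(name)
    return kept


# Hypothetical repository listing, for illustration only
repo_files = ["config.json", "vocab.txt", "pytorch_model.bin",
              "tf_model.h5", "flax_model.msgpack", "rust_model.ot"]
print(preview_filter(repo_files, ignore_patterns=["*.h5", "*.ot", "*.msgpack"]))
# ['config.json', 'vocab.txt', 'pytorch_model.bin']
```

The same result can be expressed the other way round with allow_patterns=["*.json", "*.txt", "*.bin"], which is safer if the repository later gains files in formats you did not anticipate.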

Downloading with requests

import os
import json
import requests
from urllib.parse import urljoin
from uuid import uuid4
from tqdm import tqdm

SESSIONID = uuid4().hex

VOCAB_FILE = "vocab.txt"
CONFIG_FILE = "config.json"
MODEL_FILE = "pytorch_model.bin"
BASE_URL = "https://huggingface.co/{}/resolve/main/{}"

# transformers-style user-agent; adjacent string literals avoid stray tabs in the header value
headers = {"user-agent": ("transformers/4.8.2; python/3.8.5; "
			"session_id/{}; torch/1.9.0; tensorflow/2.5.0; "
			"file_type/model; framework/pytorch; from_auto_class/False").format(SESSIONID)}

model_id = "bert-base-chinese"

# Create a folder for the model

model_dir = model_id.replace("/", "-")
os.makedirs(model_dir, exist_ok=True)

# The vocab and config files can be downloaded directly

r = requests.get(BASE_URL.format(model_id, VOCAB_FILE), headers=headers)
r.raise_for_status()
r.encoding = "utf-8"
with open(os.path.join(model_dir, VOCAB_FILE), "w", encoding="utf-8") as f:
	f.write(r.text)
	print("{} vocabulary file downloaded!".format(model_id))

r = requests.get(BASE_URL.format(model_id, CONFIG_FILE), headers=headers)
r.raise_for_status()
r.encoding = "utf-8"
with open(os.path.join(model_dir, CONFIG_FILE), "w", encoding="utf-8") as f:
	json.dump(r.json(), f, indent="\t")
	print("{} config file downloaded!".format(model_id))

# The model file is downloaded in two steps

# Step 1: get the real download address of the model
model_url = BASE_URL.format(model_id, MODEL_FILE)
r = requests.head(model_url, headers=headers)
r.raise_for_status()
url_to_download = model_url  # used as-is if the server does not redirect
if 300 <= r.status_code <= 399:
	# The Location header may be a relative path, so resolve it against the request URL
	url_to_download = urljoin(model_url, r.headers["Location"])

# Step 2: request the real address to download the model
r = requests.get(url_to_download, stream=True, proxies=None, headers=None)
r.raise_for_status()

# The progress bar is optional; this part reuses code from the transformers package
content_length = r.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(
	unit="B",
	unit_scale=True,
	total=total,
	initial=0,
	desc="Downloading Model",
)

with open(os.path.join(model_dir, MODEL_FILE), "wb") as temp_file:
	for chunk in r.iter_content(chunk_size=1024):
		if chunk:  # filter out keep-alive new chunks
			progress.update(len(chunk))
			temp_file.write(chunk)

progress.close()

print("{} model file downloaded!".format(model_id))
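The script above handles each file type separately. As a more compact alternative (a sketch, not the author's code), a single helper can stream any file in binary mode. Because requests.get follows redirects by default, the CDN hop used for large files does not strictly require the separate HEAD step. build_url and download_file are hypothetical names, and the revision parameter assumes the resolve URL accepts a branch, tag or commit in place of main.

```python
import os

import requests

BASE_URL = "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"


def build_url(repo_id: str, filename: str, revision: str = "main") -> str:
    """Return the resolve URL for one file in a model repository."""
    return BASE_URL.format(repo_id=repo_id, revision=revision, filename=filename)


def download_file(repo_id: str, filename: str, out_dir: str,
                  revision: str = "main", chunk_size: int = 1 << 20) -> str:
    """Stream one file to disk in binary mode and return its local path."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, filename)
    # GET follows redirects automatically, including the redirect to the CDN
    with requests.get(build_url(repo_id, filename, revision), stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(path, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return path
```

Usage would be a loop such as: for name in ["vocab.txt", "config.json", "pytorch_model.bin"]: download_file("bert-base-chinese", name, "bert-base-chinese"). Writing every file in binary mode keeps config.json byte-for-byte identical to the original instead of re-serializing it.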

Downloading with Git LFS

Preparation

The Git LFS approach is much simpler than the hand-written requests solution above. In addition to Git, you need to install Git LFS. Taking Windows as an example, initialize it with the following command:

git lfs install

Model download

We again use bert-base-chinese as the example. Open its model page and you will see a Use in Transformers button in the upper-right corner.

Click the button to see the exact download command.

Copy the command and run it in a terminal to download the model. The downloaded files have the same layout as those produced by the code above, but in terms of user experience this approach is clearly more elegant.
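If the clone needs to be scripted, for example when setting up a new training server, the two Git commands can be run from Python with subprocess. A minimal sketch, assuming git and Git LFS are installed and on PATH and that the clone URL follows the https://huggingface.co/<repo_id> pattern; clone_command and clone_model are hypothetical helper names.

```python
import subprocess
from typing import List, Optional


def clone_command(repo_id: str, dest: Optional[str] = None) -> List[str]:
    """Build the git clone command for a Hugging Face model repository."""
    cmd = ["git", "clone", f"https://huggingface.co/{repo_id}"]
    if dest is not None:
        cmd.append(dest)
    return cmd


def clone_model(repo_id: str, dest: Optional[str] = None) -> None:
    """Initialize Git LFS, then clone the whole repository, LFS files included."""
    subprocess.run(["git", "lfs", "install"], check=True)
    subprocess.run(clone_command(repo_id, dest), check=True)
```

check=True makes a failed command raise CalledProcessError instead of silently leaving a half-downloaded folder behind.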

However, this approach has a drawback: every file in the repository is downloaded, which can greatly lengthen the download. The directory contains model files for three different frameworks: flax_model.msgpack, tf_model.h5 and pytorch_model.bin, and bert-base-uncased additionally contains a Rust model, rust_model.ot. If you only want the model file for one framework, a plain clone like this cannot be limited to that file.

To download exactly the files you need, use Hugging Face Hub instead, as described in the Hugging Face Hub section above.
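With huggingface_hub, the single-framework case can also be handled file by file with hf_hub_download, which fetches one named file from a repository. A minimal sketch, assuming a PyTorch-only BERT checkpoint needs just config.json, vocab.txt and pytorch_model.bin, and a huggingface_hub version recent enough to accept the local_dir argument; download_files and PYTORCH_FILES are hypothetical names.

```python
from typing import List, Sequence

# Files assumed sufficient for a PyTorch-only BERT checkpoint such as bert-base-chinese
PYTORCH_FILES = ("config.json", "vocab.txt", "pytorch_model.bin")


def download_files(repo_id: str, local_dir: str,
                   filenames: Sequence[str] = PYTORCH_FILES) -> List[str]:
    """Download only the listed files into local_dir and return their local paths."""
    # Lazy import so the sketch can be loaded without huggingface_hub installed
    from huggingface_hub import hf_hub_download

    return [hf_hub_download(repo_id=repo_id, filename=name, local_dir=local_dir)
            for name in filenames]
```

Calling download_files("bert-base-chinese", "bert-base-chinese") would skip tf_model.h5 and flax_model.msgpack entirely, and the resulting folder can be passed straight to from_pretrained.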


Origin blog.csdn.net/weixin_42452716/article/details/131121969