Cet article est traduit du document anglais Compile PyTorch Models .
L'auteur est Alex Wong .
Plus de documents chinois TVM sont accessibles → TVM Chinese Station .
Cet article décrit comment déployer des modèles PyTorch avec Relay.
PyTorch doit être installé en premier. De plus, TorchVision doit également être installé et utilisé comme collection de modèles (zoo modèle).
Installation rapide via pip :
pip install torch==1.7.0
pip install torchvision==0.8.1
Ou consultez le site officiel : https://pytorch.org/get-started/locally/
La version PyTorch doit être compatible avec la version TorchVision.
Actuellement, TVM prend en charge PyTorch 1.7 et 1.4, d'autres versions peuvent être instables.
import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
# 导入 PyTorch
import torch
import torchvision
Charger le modèle PyTorch pré-entraîné
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()
# 通过追踪获取 TorchScripted 模型
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
输出结果:
Téléchargement : « https://download.pytorch.org/models/resnet18-f37072fd.pth » vers /workspace/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
0%| | 0.00/44.7M [00:00<?, ?B/s]
11%|# | 4.87M/44.7M [00:00<00:00, 51.0MB/s]
22%|##1 | 9.73M/44.7M [00:00<00:00, 49.2MB/s]
74%|#######3 | 32,9 M/44,7 M [00:00<00:00, 136 Mo/s]
100 %|##########| 44,7 M/44,7 M [00:00<00:00, 129 Mo/s]
Charger l'image de test
Exemple de chat classique :
from PIL import Image
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
# 预处理图像,并将其转换为张量
from torchvision import transforms
my_preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)
Importer le graphique de calcul dans Relay
Convertissez le graphique de calcul PyTorch en graphique de calcul Relay. input_name peut être n'importe quelle valeur.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Construction de relais
Avec la spécification d'entrée donnée, compilez le graphe de calcul vers la cible llvm.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
Résultat de sortie :
/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
"target_host parameter is going to be deprecated. "
Exécution de graphes informatiques portables sur TVM
Déployez le modèle compilé sur la cible :
from tvm.contrib import graph_executor
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)
Trouver le nom de la taxonomie
Dans un ensemble de classification de 1000 classes, trouvez la première avec le score le plus élevé :
synset_url = "".join(
[
"https://raw.githubusercontent.com/Cadene/",
"pretrained-models.pytorch/master/data/",
"imagenet_synsets.txt",
]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synsets = f.readlines()
synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
class_url = "".join(
[
"https://raw.githubusercontent.com/Cadene/",
"pretrained-models.pytorch/master/data/",
"imagenet_classes.txt",
]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path) as f:
class_id_to_key = f.readlines()
class_id_to_key = [x.strip() for x in class_id_to_key]
# 获得 TVM 的前 1 个结果
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]
# 将输入转换为 PyTorch 变量,并获取 PyTorch 结果进行比较
with torch.no_grad():
torch_img = torch.from_numpy(img)
output = model(torch_img)
# 获得 PyTorch 的前 1 个结果
top1_torch = np.argmax(output.numpy())
torch_class_key = class_id_to_key[top1_torch]
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
Résultat de sortie :
Relay top-1 id: 281, class name: tabby, tabby cat
Torch top-1 id: 281, class name: tabby, tabby cat