Compiler le modèle PyTorch

Cet article est traduit du document anglais Compile PyTorch Models .

L'auteur est Alex Wong .

Plus de documents chinois TVM sont accessibles → TVM Chinese Station .

Cet article décrit comment déployer des modèles PyTorch avec Relay.

PyTorch doit être installé en premier. De plus, TorchVision doit également être installé et utilisé comme collection de modèles (zoo modèle).

Installation rapide via pip :

pip install torch==1.7.0
pip install torchvision==0.8.1

Ou consultez le site officiel : https://pytorch.org/get-started/locally/

La version PyTorch doit être compatible avec la version TorchVision.

Actuellement, TVM prend en charge PyTorch 1.7 et 1.4, d'autres versions peuvent être instables.

import tvm
from tvm import relay

import numpy as np

from tvm.contrib.download import download_testdata

# 导入 PyTorch
import torch
import torchvision

Charger le modèle PyTorch pré-entraîné​

model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# 通过追踪获取 TorchScripted 模型
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
输出结果:

Téléchargement : « https://download.pytorch.org/models/resnet18-f37072fd.pth » vers /workspace/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

0%| | 0.00/44.7M [00:00<?, ?B/s]
11%|# | 4.87M/44.7M [00:00<00:00, 51.0MB/s]
22%|##1 | 9.73M/44.7M [00:00<00:00, 49.2MB/s]
74%|#######3 | 32,9 M/44,7 M [00:00<00:00, 136 Mo/s]
100 %|##########| 44,7 M/44,7 M [00:00<00:00, 129 Mo/s]

Charger l'image de test

Exemple de chat classique :

from PIL import Image

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# 预处理图像,并将其转换为张量
from torchvision import transforms

my_preprocess = transforms.Compose(
 [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

Importer le graphique de calcul dans Relay​

Convertissez le graphique de calcul PyTorch en graphique de calcul Relay. input_name peut être n'importe quelle valeur.

input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

Construction de relais

Avec la spécification d'entrée donnée, compilez le graphe de calcul vers la cible llvm.

target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

Résultat de sortie :

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
 "target_host parameter is going to be deprecated. "

Exécution de graphes informatiques portables sur TVM​

Déployez le modèle compilé sur la cible :

from tvm.contrib import graph_executor

dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)

Trouver le nom de la taxonomie

Dans un ensemble de classification de 1000 classes, trouvez la première avec le score le plus élevé :

synset_url = "".join(
 [
 "https://raw.githubusercontent.com/Cadene/",
 "pretrained-models.pytorch/master/data/",
 "imagenet_synsets.txt",
 ]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synsets = f.readlines()

synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}

class_url = "".join(
 [
 "https://raw.githubusercontent.com/Cadene/",
 "pretrained-models.pytorch/master/data/",
 "imagenet_classes.txt",
 ]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path) as f:
    class_id_to_key = f.readlines()

class_id_to_key = [x.strip() for x in class_id_to_key]

# 获得 TVM 的前 1 个结果
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]

# 将输入转换为 PyTorch 变量,并获取 PyTorch 结果进行比较
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

 # 获得 PyTorch 的前 1 个结果
    top1_torch = np.argmax(output.numpy())
    torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))

Résultat de sortie :

Relay top-1 id: 281, class name: tabby, tabby cat
Torch top-1 id: 281, class name: tabby, tabby cat

Téléchargez le code source Python : from_pytorch.py

Télécharger le bloc-notes Jupyter : from_pytorch.ipynb

Je suppose que tu aimes

Origine blog.csdn.net/HyperAI/article/details/130369032
conseillé
Classement