PyTorchモデル(サブクラスnn.Module)の中間表現であるTorchScriptは、高性能環境(C ++など)で実行できます。
変換方法は2つあり、1つは変換を追跡する方法、もう1つは注釈を介して変換する方法です。
1.コンバージョンを追跡する
一般的に使用されるのはコンバージョンを追跡することですが、この方法には入力サイズが固定されているという欠点があります。
公式ウェブサイトに掲載されている追跡例:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
トレース ScriptModule
は、通常のPyTorchモジュールと同じように評価できるようになりました。
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
インターネットで見られる最も一般的な例は、コンバージョントラッキングです。
2.コメント変換
公式サイトの例
注釈によるトーチスクリプトへの変換
モデルが特定の形式の制御フローを採用している場合など、特定の状況では、トーチスクリプトでモデルを直接記述し、それに応じてモデルに注釈を付けることができます。たとえば、次のバニラPytorchモデルがあるとします。
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
forward
このモジュールの 方法は、入力に依存する制御フローを使用するため、トレースには適していません。代わりに、に変換でき ScriptModule
ます。モジュールをに変換するには ScriptModule
、torch.jit.script
次のようにモジュールをコンパイルする必要があります 。
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
nn.Module
TorchScriptがまだサポートしていないPython機能を使用しているために、メソッドの一部を除外する必要がある場合は 、それらに注釈を付けることができます。 @torch.jit.ignore
my_module
ScriptModule
シリアル化の準備ができて いるインスタンスです 。
ステップ2:スクリプトモジュールをファイルにシリアル化する
あなたがしたら ScriptModule
、あなたの手の中に、いずれかのトレースやPyTorchモデルに注釈を付けるから、あなたはそれをファイルにシリアライズする準備が整いました。後で、このファイルからC ++でモジュールをロードし、Pythonに依存せずに実行できるようになります。ResNet18
トレースの例で前に示したモデルをシリアル化する とします。このシリアル化を実行する には、モジュールでsaveを呼び出し 、ファイル名を渡します。
traced_script_module.save("traced_resnet_model.pt")
これにより、traced_resnet_model.pt
作業ディレクトリにファイルが作成さ れます。シリアル化もご希望の場合は my_module
、お電話ください my_module.save("my_module_model.pt")
。Pythonの領域を正式に終了し、C ++の領域に移行する準備が整いました。