PytorchモデルをTorchscriptモデルに変換する

PyTorchモデル(サブクラスnn.Module)の中間表現であるTorchScriptは、高性能環境(C ++など)で実行できます。

変換方法は2つあり、1つは変換を追跡する方法、もう1つは注釈を介して変換する方法です。

1.コンバージョンを追跡する

一般的に使用されるのはコンバージョンを追跡することですが、この方法には入力サイズが固定されているという欠点があります。

公式ウェブサイトに掲載されている追跡例:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

トレース ScriptModule は、通常のPyTorchモジュールと同じように評価できるようになりました。

In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

インターネットで見られる最も一般的な例は、コンバージョントラッキングです。

 

2.コメント変換

公式サイトの例

注釈によるトーチスクリプトへの変換

モデルが特定の形式の制御フローを採​​用している場合など、特定の状況では、トーチスクリプトでモデルを直接記述し、それに応じてモデルに注釈を付けることができます。たとえば、次のバニラPytorchモデルがあるとします。

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

forward このモジュールの 方法は、入力に依存する制御フローを使用するため、トレースには適していません。代わりに、に変換でき ScriptModuleます。モジュールをに変換するには ScriptModuletorch.jit.script 次のようにモジュールをコンパイルする必要があります 

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

nn.Module TorchScriptがまだサポートしていないPython機能を使用しているために、メソッドの一部を除外する必要がある場合は 、それらに注釈を付けることができます。 @torch.jit.ignore

my_moduleScriptModule シリアル化の準備ができて いるインスタンスです 

ステップ2:スクリプトモジュールをファイルにシリアル化する

あなたがしたら ScriptModule 、あなたの手の中に、いずれかのトレースやPyTorchモデルに注釈を付けるから、あなたはそれをファイルにシリアライズする準備が整いました。後で、このファイルからC ++でモジュールをロードし、Pythonに依存せずに実行できるようになります。ResNet18 トレースの例で前に示しモデルをシリアル化する とします。このシリアル化を実行する には、モジュールでsave呼び出し 、ファイル名を渡します。

traced_script_module.save("traced_resnet_model.pt")

これにより、traced_resnet_model.pt 作業ディレクトリにファイルが作成さ れます。シリアル化もご希望の場合は my_module、お電話ください my_module.save("my_module_model.pt") 。Pythonの領域を正式に終了し、C ++の領域に移行する準備が整いました。

 

 

 

 

参照:PyTorchモデルをTorchScript形式に変換する

おすすめ

転載: blog.csdn.net/juluwangriyue/article/details/108635280
おすすめ