参考:pytorch公式サイト
pytorchモデルはC ++呼び出し可能シリアル化ファイルに変換され、手順は次のとおりです(traceメソッドを使用)。
1.モデルパラメータをロードし、評価モードに設定します。
2.モデルをトレーニングするときの入力と同じサイズの入力を定義します。
3.順伝播を実行し、トレースを記録して、ptファイルを保存します。
import torch
from MobileNetV2 import mobilenet_v2
model = mobilenet_v2()
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth")
model.load_state_dict(state_dict)
model.eval()
x = torch.rand(1,3,128,128)
ts = torch.jit.trace(model, x)
ts.save('mobilenet_v2.pt')