pytorchモデルをptに変換する(c ++呼び出し可能)

参考:pytorch公式サイト

pytorchモデルはC ++呼び出し可能シリアル化ファイルに変換され、手順は次のとおりです(traceメソッドを使用)。

1.モデルパラメータをロードし、評価モードに設定します。

2.モデルをトレーニングするときの入力と同じサイズの入力を定義します。

3.順伝播を実行し、トレースを記録して、ptファイルを保存します。

import torch
from MobileNetV2 import mobilenet_v2


model = mobilenet_v2()
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth")
model.load_state_dict(state_dict)
model.eval()


x = torch.rand(1,3,128,128)
ts = torch.jit.trace(model, x)
ts.save('mobilenet_v2.pt')

 

おすすめ

転載: blog.csdn.net/Guo_Python/article/details/108825011