学习自用,欢迎指正
pytorch中训练的.pth模型不能在C++中直接调用,需要转换成.pt格式的文件,故有此文。
在转换之前需要知道的东西:
1.要转换的pth文件是从那个网络训练出来的?
2.训练这个网络时的输入参数是哪些,各是多少
了解之后利用如下代码进行转换:
import torch
from models.network_swinir import SwinIR #你训练的网络
model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv') # 训练此网络时往里输入的参数
state_dict = torch.load("S_x4.pth") # 要转换的文件路径
model.load_state_dict(state_dict, False)
model.eval() # 切换到eval()
example = torch.rand(1, 3, 320, 480) # 生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")