Converting a PyTorch DeepLabv3+ model's format (.pth to .pt)

To call the DeepLabv3+ model from C++, use torch.jit.trace to convert the .pth weights produced by PyTorch training into a TorchScript .pt file.

Reference: https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py

In predict.py, add the conversion after the prediction loop:

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))

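        # Convert .pth to .pt: trace the bare network (model.module, not the
        # nn.DataParallel wrapper), using the last preprocessed image as the example input.
        # img is already on device, so no extra .to(device) is needed.
        traced_model = torch.jit.trace(model.module, img)
        traced_model.save("DeeplabV3plus.pt")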
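The trace-and-save step can be checked on its own before running it on DeepLabv3+. Below is a minimal sketch, assuming a hypothetical one-layer TinySegNet in place of the real network, a random 1x3x64x64 tensor in place of a preprocessed image, and a hypothetical output path in a temporary directory. It traces the model, saves the .pt file, reloads it with torch.jit.load and compares the outputs with the eager model.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabv3+: maps an NCHW image to per-class logits."""

    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def export_traced(model: nn.Module, example: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Trace `model` on `example` and save the resulting TorchScript module to `path`."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)
    return traced


if __name__ == "__main__":
    model = TinySegNet()
    example = torch.randn(1, 3, 64, 64)  # dummy NCHW input instead of a real image
    path = os.path.join(tempfile.mkdtemp(), "tiny_seg.pt")  # hypothetical output path
    export_traced(model, example, path)

    # Reload the saved file, as a C++ program would with torch::jit::load,
    # and check that it reproduces the eager model's output.
    reloaded = torch.jit.load(path)
    with torch.no_grad():
        print(torch.allclose(model(example), reloaded(example), atol=1e-6))
        # Per-pixel class map, as in predict.py: argmax over the channel dim gives HW
        print(reloaded(example).max(1)[1][0].shape)
```

Tracing only records the operations executed for the example input, so any data-dependent Python control flow in forward would be frozen into the .pt file; a plain convolutional segmentation network does not rely on such branching.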

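Note: if you write traced_model = torch.jit.trace(model, img.to(device)) instead, i.e. trace the nn.DataParallel wrapper rather than model.module, tracing fails with this error:
Could not export Python function call 'Scatter'.
The Scatter call comes from nn.DataParallel.forward, which splits the input across GPUs with a Python-level autograd Function that the tracer cannot export. Tracing model.module bypasses the wrapper.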
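A small helper can make the unwrapping explicit so the same export code works whether or not the model is wrapped. This is a minimal sketch, assuming predict.py wraps the network in nn.DataParallel (which the use of model.module implies); unwrap_model is a hypothetical helper name and the 1x1 convolution is a stand-in for DeepLabv3+.

```python
import torch
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the bare network if `model` is wrapped in (Distributed)DataParallel."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


if __name__ == "__main__":
    base = nn.Conv2d(3, 21, kernel_size=1)  # hypothetical stand-in for DeepLabv3+
    wrapped = nn.DataParallel(base)         # assumed to mirror how predict.py wraps the model
    target = unwrap_model(wrapped).eval()
    print(target is base)

    # Build the example input on the same device as the weights
    # (DataParallel moves the module to the first GPU when one is available).
    device = next(target.parameters()).device
    traced = torch.jit.trace(target, torch.randn(1, 3, 32, 32, device=device))
```

Checking the wrapper type instead of always writing model.module avoids an AttributeError when the same script is run on a model that was never wrapped.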


Reposted from blog.csdn.net/slender_1031/article/details/127848808