PyTorch DeepLabV3+: converting the network model format (.pth to .pt)

To call the DeepLabV3+ model from C++, use torch.jit.trace to convert the .pth file produced by PyTorch training into a TorchScript .pt file.
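The following minimal sketch shows what tracing does. It uses a tiny stand-in network instead of DeepLabV3+; TinySegNet, trace_and_save and the file name tiny_seg.pt are hypothetical names chosen for illustration. The sketch traces the model with an example input, saves the result as a TorchScript .pt file, reloads it with torch.jit.load and checks that the reloaded module gives the same output as the eager model.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a segmentation network: NCHW image -> NCHW class logits."""

    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def trace_and_save(model: nn.Module, example: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Trace `model` with `example`, save it as TorchScript, and return the reloaded module."""
    model.eval()  # trace in inference mode so dropout/batchnorm behave as at test time
    with torch.no_grad():
        traced = torch.jit.trace(model, example)  # records the ops run on `example`
    traced.save(path)
    return torch.jit.load(path)


if __name__ == "__main__":
    torch.manual_seed(0)
    net = TinySegNet()
    example = torch.rand(1, 3, 64, 64)  # one RGB image, NCHW
    loaded = trace_and_save(net, example, "tiny_seg.pt")
    with torch.no_grad():
        print("outputs match:", torch.allclose(net(example), loaded(example), atol=1e-6))
```

Tracing only records the operations actually executed for the example input, so data-dependent control flow inside forward is not captured. torch.jit.script is the alternative for models that need it.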

Reference: https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py

Add the conversion to predict.py. The last lines below, after the prediction loop, trace the model and save it:

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))

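        # Convert .pth to .pt: trace the underlying network (model.module), using the
        # last preprocessed image (already on `device`) as the example input
        traced_model = torch.jit.trace(model.module, img)
        traced_model.save("DeeplabV3plus.pt")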
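Tracing inside the prediction loop reuses the last preprocessed image as the example input. If you only need the conversion, you can instead load the .pth checkpoint directly and trace with a dummy tensor, as in the sketch below. convert_pth_to_pt is a hypothetical helper, and it rests on two assumptions. First, the checkpoint is either a bare state_dict or a dict holding the weights under a "model_state" key; adjust state_key to match your training script. Second, input_shape is a placeholder that you should set to the image size you plan to feed from C++. The model passed in must be built with the same architecture as in training, and it must not be wrapped in nn.DataParallel.

```python
import torch
import torch.nn as nn


def convert_pth_to_pt(model: nn.Module, pth_path: str, pt_path: str,
                      input_shape=(1, 3, 513, 513), state_key: str = "model_state") -> None:
    """Hypothetical helper: load .pth weights into `model`, trace it, save TorchScript .pt."""
    checkpoint = torch.load(pth_path, map_location="cpu")
    # Assumption: weights are either the whole file or stored under `state_key`.
    if isinstance(checkpoint, dict) and state_key in checkpoint:
        state_dict = checkpoint[state_key]
    else:
        state_dict = checkpoint
    # Strip the "module." prefix left by saving an nn.DataParallel-wrapped model, if present.
    state_dict = {(k[len("module."):] if k.startswith("module.") else k): v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    dummy = torch.rand(*input_shape)  # only shape and dtype matter for tracing
    with torch.no_grad():
        traced = torch.jit.trace(model, dummy)
    traced.save(pt_path)
```

To use it, build the network exactly as predict.py does but without the nn.DataParallel wrapper, then call convert_pth_to_pt(model, "<your checkpoint>.pth", "DeeplabV3plus.pt").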

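Note that you must trace model.module, not model. In predict.py the network is wrapped in nn.DataParallel, which is why model.module exists. The wrapper's forward splits the input across GPUs with a Python-level Scatter function that the tracer cannot record. Writing traced_model = torch.jit.trace(model, img.to(device)) therefore fails with:
Could not export Python function call 'Scatter'.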
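A small guard avoids this error whether or not the model is wrapped. The sketch below uses a hypothetical helper, unwrap_model. It returns .module for nn.DataParallel and DistributedDataParallel wrappers and returns the model unchanged otherwise, so one tracing line works in both cases, for example torch.jit.trace(unwrap_model(model), img).

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the wrapped network if `model` is a (Distributed)DataParallel."""
    if isinstance(model, (nn.DataParallel, DistributedDataParallel)):
        return model.module  # the real network, free of the Python-level Scatter call
    return model  # already a plain nn.Module
```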


Origin blog.csdn.net/slender_1031/article/details/127848808