To call the DeepLabV3+ model from C++ (for example through LibTorch), use `torch.jit.trace` to convert the .pth checkpoint produced by PyTorch training into a TorchScript .pt file.
Reference: https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py
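Add the following to predict.py. The lines under the ".pth to .pt" comment are the new part; the rest is the existing prediction loop.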
```python
with torch.no_grad():
    model = model.eval()
    for img_path in tqdm(image_files):
        ext = os.path.basename(img_path).split('.')[-1]
        img_name = os.path.basename(img_path)[:-len(ext)-1]
        img = Image.open(img_path).convert('RGB')
        img = transform(img).unsqueeze(0)  # To tensor of NCHW
        img = img.to(device)

        pred = model(img).max(1)[1].cpu().numpy()[0]  # HW
        colorized_preds = decode_fn(pred).astype('uint8')
        colorized_preds = Image.fromarray(colorized_preds)
        if opts.save_val_results_to:
            colorized_preds.save(os.path.join(opts.save_val_results_to, img_name + '.png'))

    # Convert .pth to .pt: trace the network inside the nn.DataParallel wrapper
    # (model.module), using the last preprocessed image (already on device) as the example input
    traced_model = torch.jit.trace(model.module, img)
    traced_model.save("DeeplabV3plus.pt")
```
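Before moving to C++, it can help to check that the saved file reloads and reproduces the eager model's output. The sketch below does this with `torch.jit.load`, which reads the same TorchScript format that LibTorch's `torch::jit::load` reads in C++. `TinySegNet` and `trace_and_verify` are hypothetical names; the single convolution only stands in for DeepLabV3+ so the example runs without a checkpoint.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabV3+: maps an NCHW image to per-class logits."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def trace_and_verify(model: nn.Module, example: torch.Tensor, path: str) -> bool:
    """Trace `model` on `example`, save it to `path`, reload it, and compare outputs."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        reloaded = torch.jit.load(path)  # same file format torch::jit::load reads in C++
        expected = model(example)
        actual = reloaded(example)
    return torch.allclose(expected, actual, atol=1e-6)


if __name__ == "__main__":
    ok = trace_and_verify(TinySegNet(), torch.randn(1, 3, 64, 64), "tiny_seg.pt")
    print("traced output matches eager output:", ok)
```

With the real model, you would pass `model.module` and the `img` tensor from predict.py in place of the stand-in and random input.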
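Note that if you trace the wrapper itself, i.e. `traced_model = torch.jit.trace(model, img.to(device))`, tracing fails with the following error:

```
Could not export Python function call 'Scatter'.
```

The wrapper's forward pass splits the input across GPUs through a Python `Scatter` function that the tracer cannot record, so trace `model.module`, the plain network underneath, instead.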
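If the same export script has to handle both wrapped and unwrapped models, a small helper can strip the wrapper before tracing. This is a sketch assuming the wrapper is `nn.DataParallel` (as the `model.module` call and the `Scatter` error suggest) or `DistributedDataParallel`, both of which keep the real network in `.module`; `unwrap_model` is a hypothetical name.

```python
import torch
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying network if `model` is a DataParallel/DDP wrapper.

    Tracing the wrapper records its scatter step, which torch.jit.trace
    cannot export, so the wrapped module is traced instead.
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    while isinstance(model, wrappers):
        model = model.module
    return model


if __name__ == "__main__":
    net = nn.Conv2d(3, 2, kernel_size=1)
    wrapped = nn.DataParallel(net)  # assumed to mirror the wrapper used in predict.py
    model = unwrap_model(wrapped).eval()
    # DataParallel may move the module to a GPU, so build the example on its device
    example = torch.randn(1, 3, 16, 16, device=next(model.parameters()).device)
    traced = torch.jit.trace(model, example)
    print(traced(example).shape)
```

Unwrapping in a loop also covers the occasional case of a wrapper nested inside another wrapper.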