Saving and loading PyTorch models

1. Related functions

torch.save

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)

torch.load

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
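
Both functions work on arbitrary picklable objects, not just models. A quick self-contained example (the filename tensors.pt is arbitrary):

import torch

# save a plain dict of tensors, then read it back
torch.save({'weights': torch.randn(3, 3), 'step': 10}, 'tensors.pt')
loaded = torch.load('tensors.pt')
print(loaded['weights'].shape, loaded['step'])   # torch.Size([3, 3]) 10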

The map_location argument chooses whether the saved tensors are loaded onto the CPU or a GPU:

# Saved on the CPU, load onto the GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))

# Saved on the GPU, load onto the CPU
model.load_state_dict(torch.load(PATH, map_location='cpu'))
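
When weights saved on the CPU are loaded onto a GPU, the model itself still has to be moved to that device before inputs are fed to it. A minimal sketch (PATH, TheModelClass, args and kwargs are placeholders, as elsewhere in this article):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = TheModelClass(*args, **kwargs)               # instantiate the architecture first
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)                                     # parameters must live on the same device as the inputs
model.eval()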

model.load_state_dict()
 

model.load_state_dict(state_dict, strict=True)
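
With strict=True (the default) every key in the checkpoint must match a parameter of the model. Setting strict=False skips mismatched keys and reports them, which can help when only part of a checkpoint fits the current architecture. A small sketch (PATH and model are assumed to exist already):

state_dict = torch.load(PATH, map_location="cpu")
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)      # parameters the checkpoint did not provide
print(result.unexpected_keys)   # checkpoint entries the model has no parameter for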

2. Save and load directly

Save and load the entire model (training is finished, no further training needed). This stores the whole model object, so it takes up more space.

# Save
torch.save(model, PATH)
# Load
# The model class must be defined somewhere in the loading code
model = torch.load(PATH)
model.eval()

3. Save and load with state_dict (recommended)

Only the weight parameters (the state_dict) are saved, so the model has to be instantiated first when loading.

Otherwise PyTorch raises an AttributeError.

Save and load the state_dict (training is finished, no further training needed)

Save

torch.save(model.state_dict(), PATH)

Load

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()   # the model must be instantiated first, otherwise an error is raised

The file is generally saved with a .pt or .pth extension.

1. load_state_dict() expects a dict, not the PATH that the model was saved to. So model.load_state_dict(PATH) is wrong; use model.load_state_dict(torch.load(PATH)) instead.
2. If you want to save the model that performs best on the validation set, best_model_state = model.state_dict() is wrong, because state_dict() returns a shallow copy: best_model_state keeps updating as training continues, and what finally gets saved is the last (possibly overfitted) model. The correct approach is best_model_state = deepcopy(model.state_dict()), as in the sketch below.
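
A sketch of point 2, tracking the best validation score during training; train_one_epoch, evaluate, num_epochs and val_acc are placeholder names, not part of any API:

from copy import deepcopy
import torch

best_acc = 0.0
best_model_state = None

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)          # your training step
    val_acc = evaluate(model)                  # your validation step
    if val_acc > best_acc:
        best_acc = val_acc
        # deepcopy detaches the snapshot from the live parameters;
        # a plain model.state_dict() would keep changing as training continues
        best_model_state = deepcopy(model.state_dict())

torch.save(best_model_state, PATH)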

Save and load the state_dict as a checkpoint (training is not finished and will continue)

Save

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            # ... any other values of your own
            }, PATH)

Load

model = XIAOHU(*args, **kwargs)                    # your own model class
optimizer = torch.optim.Adam(model.parameters(), **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
...
model.eval()     # for inference
# - or -
model.train()    # to continue training
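
To actually resume, the usual pattern is to call model.train() and continue the epoch loop from the saved epoch. A rough sketch, with train_one_epoch and num_epochs as placeholders:

start_epoch = checkpoint['epoch'] + 1

model.train()                                      # back to training mode
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)       # your training step
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, PATH)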

Source: blog.csdn.net/qq_37925923/article/details/126919333