Saving a model with torch

Save

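import torch

# Bundle the model weights, the optimizer state and the current epoch into one checkpoint dict
state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, 'dir.pth')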
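The snippet above assumes `model`, `optimizer` and `epoch` already exist. A minimal self-contained sketch follows, assuming a toy `nn.Linear` model, SGD with momentum and a temporary file named `checkpoint.pth` (all hypothetical, chosen only for illustration). It saves the checkpoint and reads it back to show what the file contains.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy model and optimizer, just so there is something to save
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds real state (momentum buffers)
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch = 3  # assumed: the epoch that just finished
state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')  # hypothetical file name
torch.save(state, path)

# Read the file back and inspect what was stored
saved = torch.load(path)
print(sorted(saved.keys()))       # ['epoch', 'net', 'optimizer']
print(list(saved['net'].keys()))  # ['weight', 'bias']
```

Saving the optimizer state along with the weights matters for optimizers that keep per-parameter buffers (momentum here). Without it, resumed training would restart those buffers from zero.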

Load

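# model and optimizer must already be constructed with the same structure as when saving
checkpoint = torch.load('dir.pth')
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1  # resume from the epoch after the saved one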
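As a rough sketch of resuming, the example below first writes a checkpoint (standing in for an earlier run), then builds fresh model and optimizer objects and restores them. The helper `make_model_and_optimizer`, the Adam settings, the saved epoch of 4 and `num_epochs = 8` are all assumptions for illustration. `map_location='cpu'` is optional; it lets a checkpoint written on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model_and_optimizer():
    # Hypothetical architecture; loading needs the same structure used when saving
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model, optimizer


# --- Earlier run: train one step and save a checkpoint ---
model, optimizer = make_model_and_optimizer()
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save({'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': 4}, path)

# --- Resume: build fresh objects, then restore their state ---
new_model, new_optimizer = make_model_and_optimizer()
checkpoint = torch.load(path, map_location='cpu')
new_model.load_state_dict(checkpoint['net'])
new_optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

num_epochs = 8  # assumed total number of epochs
for epoch in range(start_epoch, num_epochs):
    pass  # one epoch of training would go here

print(start_epoch)  # 5
```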

Reposted from blog.csdn.net/weixin_42764932/article/details/112181920