pytorch训练模型保存与加载

torch.save(model.state_dict(), 'model.pth')

torch.save(model.state_dict(), 'model.pth')是将PyTorch模型的状态字典(state dictionary)保存到文件中的常用代码。具体来说,model.state_dict()返回一个Python字典,该字典包含了模型所有可学习参数的名称和对应的张量值。torch.save函数将这个字典保存到文件model.pth中,以便在需要时重新加载模型参数。

model = MyModel()
model.load_state_dict(torch.load('model.pth'))

model = MyModel() model.load_state_dict(torch.load('model.pth')) 是将训练好的模型参数加载到新模型中的常用代码。具体来说,torch.load函数会返回一个Python字典,该字典包含了被保存的模型参数。model.load_state_dict方法将这些参数加载到新的模型实例中,从而创建一个与原模型参数相同的新模型。

猜你喜欢

转载自blog.csdn.net/weixin_50752408/article/details/129654896