torch.save(model.state_dict(), 'model.pth')
torch.save(model.state_dict(), 'model.pth')
是将PyTorch模型的状态字典(state dictionary)保存到文件中的常用代码。具体来说,model.state_dict()
返回一个Python字典,该字典包含了模型所有可学习参数的名称和对应的张量值。torch.save
函数将这个字典保存到文件model.pth
中,以便在需要时重新加载模型参数。
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model = MyModel() model.load_state_dict(torch.load('model.pth'))
是将训练好的模型参数加载到新模型中的常用代码。具体来说,torch.load
函数会返回一个Python字典,该字典包含了被保存的模型参数。model.load_state_dict
方法将这些参数加载到新的模型实例中,从而创建一个与原模型参数相同的新模型。