参看博客:
Pytorch实战总结篇之模型训练、评估与使用_Miracle8070-CSDN博客_pytorch模型评估
程序如下:
# 当训练意外中断,或者想看上次训练结果
# 命令行执行
model = Net().to(device)
model.load_state_dict(torch.load(".\mnist_cnn.pt"))
Out[10]: <All keys matched successfully>
test(model, device, test_dataloader)
test loss: 0.00010815935134887696 , Accuracy: 99.3
# 训练意外或者非意外中断 继续上次训练的模型继续训练
lr = 0.000001 # learning rate
momentum = 0.4
optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=momentum)
model = Net().to(device)
model.load_state_dict(torch.load(".\mnist_cnn.pt")) # 加载当前目录下的模型(上次训练的模型)
num_epochs = 20000
if __name__ == '__main__': # 多线程 num_worker 不等于0 时
for epoch in range(10000,num_epochs):
train(model, device, train_dataloader, optimizer, epoch)
test(model, device, test_dataloader)
torch.save(model.state_dict(), "mnist_cnn.pt")