训练中断,继续上次训练结果继续训练

 参看博客:

Pytorch实战总结篇之模型训练、评估与使用_Miracle8070-CSDN博客_pytorch模型评估

程序如下:

# 当训练意外中断,或者想看上次训练结果
# 命令行执行
model = Net().to(device)

model.load_state_dict(torch.load(".\mnist_cnn.pt"))
Out[10]: <All keys matched successfully>

test(model, device, test_dataloader)
test loss: 0.00010815935134887696   , Accuracy: 99.3


# 训练意外或者非意外中断 继续上次训练的模型继续训练
lr = 0.000001 # learning rate
momentum = 0.4
optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=momentum)
model = Net().to(device)

model.load_state_dict(torch.load(".\mnist_cnn.pt")) # 加载当前目录下的模型(上次训练的模型)
num_epochs = 20000

if __name__ == '__main__':  # 多线程 num_worker 不等于0 时
    for epoch in range(10000,num_epochs):
        train(model, device, train_dataloader, optimizer, epoch)
        test(model, device, test_dataloader)
        
    torch.save(model.state_dict(), "mnist_cnn.pt")

猜你喜欢

转载自blog.csdn.net/huachuchengzhang/article/details/122832232