pytorch保存与加载模型来测试或继续训练

摘要

pytorch中与保存和加载模型有关函数有三个:
1.torch.save:将序列化的对象保存到磁盘。此函数使用Python的pickle实用程序进行序列化。使用此功能可以保存各种对象的模型,张量和字典。
2. torch.load:使用pickle的unpickle工具将pickle的对象文件反序列化到内存中。即加载save保存的东西。
3. torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典。注意,这意味着它的传入的参数应该是一个state_dict类型,也就torch.load加载出来的。

state_dict

stat_dict是一个字典,该字典包含model每一层的tensor类型的可学习参数。只有包含可学习参数的网络层才能将其参数映射到state_dict字典中,此外,stat_dict也包含优化器的state和超参数。

官网给的一个示例:

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

output:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {
    
    }
param_groups     [{
    
    'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

恢复训练实例

保存模型和加载模型的函数如下

def  save_checkpoint_state(dir,epoch,model,optimizer):
	#保存模型
    checkpoint = {
    
    
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
                }   
    if not os.path.isdir(dir):
        os.mkdir(dir)

    torch.save(checkpoint, os.path.join(dir,'checkpoint-epoch%d.tar'%(epoch)))
    
def get_checkpoint_state(dir,ckp_name,device,model,optimizer):
     # 恢复上次的训练状态
    print("Resume from checkpoint...")
    checkpoint = torch.load(os.path.join(dir,ckp_name),map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch=checkpoint['epoch']

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    print('sucessfully recover from the last state')
    return model,epoch,optimizer

如果加入了lr_scheduler,那么lr_scheduler的state_dict也要加进来。

使用时:

# 引用包省略
#保持模型函数
def save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss):
    checkpoint = {
    
    
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict()
    }
    
    torch.save(checkpoint, "checkpoint-epoch%d-loss%d.tar" % (epoch, running_loss))
# 加载模型函数   
def load_checkpoint_state(path, device, model, optimizer, scheduler):
    checkpoint = torch.load(path, map_location=device)
    
    model.load_state_dict(checkpoint["model_state_dict"])
    
    epoch = checkpoint["epoch"]
    
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    
    return model, epoch, optimizer, scheduler  


# 是否恢复训练
resume = False # True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(): 
    trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(512),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # get training dataset
    leafDiseaseCLS = CustomDataSet(images_path, is_to_ls, trans)
    
    data_loader = DataLoader(leafDiseaseCLS,
                             batch_size=16,
                             num_workers=0,
                             shuffle=True,
                             pin_memory=False)
    
    # get model
    model = EfficientNet.from_pretrained("efficientnet-b3")
    
    # extract the parameter of fully connected layer
    fc_features = model._fc.in_features
    # modify the number of classes
    model._fc = nn.Linear(fc_features, 5)
    
    model.to(device)
        
    # optimizer
    optimizer = optim.SGD(model.parameters(), 
                          lr=0.001, 
                          momentum=0.9,
                          weight_decay=5e-4)
    
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[6, 10], gamma=1/3.)
    
    # loss
    #loss_func = nn.CrossEntropyLoss()
    loss_func = FocalCosineLoss()
    
    start_epoch = -1
    
    if resume:
        model, start_epoch, optimizer,scheduler = load_checkpoint_state("../path/to/checkpoint.tar",
                                                                        device, 
                                                                        model,
                                                                        optimizer,
                                                                        scheduler)
    
    model.train()
    
    epochs = 3
    
    for epoch in range(start_epoch + 1, epochs):
        
        running_loss = 0.0
        
        print("Epoch {}/{}".format(epoch, epochs))
        
        for step, train_data in tqdm(enumerate(data_loader)):
            x_train, y_train = train_data
            
            x_train = Variable(x_train.to(device))
            y_train = Variable(y_train.to(device))
            
            # forward
            prediction = model(x_train)
            
            optimizer.zero_grad()
            
            loss = loss_func(prediction, y_train)
            
            running_loss += loss.item()
            
            # backward
            loss.backward()
            
            optimizer.step()            
            
            
        scheduler.step()
        
        # saving model
        torch.save(model.state_dict(), str(int(running_loss)) + "_" + str(epoch) + ".pth")
        
        save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss)
        
        print("Loss:{}".format(running_loss))

if __name__ == "__main__":
    train()

加载部分预训练模型

大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = torch.load("model_data/yolo_weights.pth", map_location=device)

model_dict = model.state_dict()
# 将 pretrained_dict 里不属于 model_dict 的键剔除掉
pretrained_dict = {
    
    k: v for k, v in pretrained_dict.items() if k in model_dict}
#pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
# 更新现有的 model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict)

保存 & 加载模型 来inference

保存/加载state_dict (推荐)

保存:

推荐仅仅保存模型的state_dict,保存的时候文件类型可以是.pt或.pth

torch.save(model.state_dict(), PATH)

加载:

在保存模型进行推理时,只需保存已训练模型的学习参数。

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:在测试前必须使用model.eval()把dropout和batch normalization设为测试模式。

保存/加载整个模型

保存:

torch.save(model, PATH)

加载:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

保存 & 加载一个通用Checkpoint来做测试或恢复训练

保存:

torch.save({
    
    
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()

保存用于检查或继续训练的checkpoint时,不仅要保存模型的state_dict,还要保存优化器的state_dict,因为它包含随着模型训练而更新的缓冲区和参数。其他项目需要保存的还有中断训练时的epoch,最新记录的训练损失,外部torch.nn.Embedding层等。
一般使用字典方式保存这些不同部分,然后使用torch.save()序列化字典。PyTorch约定是使用.tar文件扩展名保存这样的checkpoint,并且文件缀名用.tar。这种方式保存的话比单独保存模型文件大2-3倍。
要加载项目,请首先初始化模型和优化器,然后使用torch.load()加载本地的字典(checkpoint)。
如果做inference,必须调用model.eval()来将dropout和batch normalization层设置为评估模式。不这样做将产生不一致的推断结果。如果希望恢复训练,调用model.train()以确保这些层处于训练模式。

加载不同模型的参数warmstarting

保存:

torch.save(modelA.state_dict(), PATH)

加载:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

在转移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见方案。利用经过训练的参数,即使只有少数几个可用的参数,也将有助于热启动训练过程,并希望与从头开始训练相比,可以更快地收敛模型。
无论是从缺少部分key的state_dict加载,还是要使用比要加载的模型有更多的key的state_dict加载都行,只需要在load_state_dict()函数中将strict参数设置为False,以忽略不匹配项键。
如果要将参数从一层加载到另一层,但是某些key不匹配,只需更改要加载的state_dict中参数键的名称,以匹配要加载到的模型中的键。

跨设备保存/加载模型(CPU与GPU)

模型保存在GPU上,加载到CPU

保存

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

模型保存在GPU上,加载到GPU

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

一定要使用.to(torch.device(‘cuda’))将所有输入模型的数据转到GPU上。请注意,调用my_tensor.to(device)会在GPU上返回my_tensor的新副本。它不会覆盖my_tensor。因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device(‘cuda’))。

当然还有CPU上训练GPU来加载的,但这种情况较少,就不放操作了.

内容来自pytorch官网
要有看官网文档的心,打破畏难情绪,然后回发现看doc真的还挺简单。

猜你喜欢

转载自blog.csdn.net/yanghao201607030101/article/details/110947689