PyTorch: saving and loading models for testing or resuming training

Summary

There are three functions related to saving and loading models in PyTorch:

1. torch.save: saves a serialized object to disk. It uses Python's pickle utility for serialization. Use it to save models, tensors, and dictionaries of all kinds of objects.
2. torch.load: uses pickle's unpickling facilities to deserialize a pickled object file into memory; in other words, it loads whatever torch.save has saved.
3. torch.nn.Module.load_state_dict: loads the model's parameter dictionary from a deserialized state_dict. Note that its argument must be a state_dict, i.e. the object returned by torch.load.
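
For orientation, here is a minimal sketch that ties the three together; it assumes a model object already exists, and the file name model.pth is just a placeholder:

# Save the model's learnable parameters (its state_dict) to disk.
torch.save(model.state_dict(), "model.pth")

# Later: deserialize the saved dictionary and load it into a model
# that has the same architecture.
state_dict = torch.load("model.pth")
model.load_state_dict(state_dict)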

state_dict

A state_dict is a dictionary that maps each layer of the model to its learnable parameters (tensors). Only layers with learnable parameters (for example convolutional and linear layers) have entries in the model's state_dict. In addition, the optimizer has its own state_dict, which contains the optimizer's state and the hyperparameters it uses.

An example from the official documentation:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Output:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

Resuming training: an example

The functions for saving and loading a checkpoint are as follows:

import os
import torch

def save_checkpoint_state(dir, epoch, model, optimizer):
    # save the model checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if not os.path.isdir(dir):
        os.mkdir(dir)

    torch.save(checkpoint, os.path.join(dir, 'checkpoint-epoch%d.tar' % (epoch)))

def get_checkpoint_state(dir, ckp_name, device, model, optimizer):
    # restore the previous training state
    print("Resume from checkpoint...")
    checkpoint = torch.load(os.path.join(dir, ckp_name), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    print('successfully recovered from the last state')
    return model, epoch, optimizer

If an lr_scheduler is used, its state_dict must also be saved and restored.

When using these functions in a full training script:

# imports omitted

# function to save the model
def save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict()
    }

    torch.save(checkpoint, "checkpoint-epoch%d-loss%d.tar" % (epoch, running_loss))
# function to load the model
def load_checkpoint_state(path, device, model, optimizer, scheduler):
    checkpoint = torch.load(path, map_location=device)
    
    model.load_state_dict(checkpoint["model_state_dict"])
    
    epoch = checkpoint["epoch"]
    
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    
    return model, epoch, optimizer, scheduler  


# whether to resume training
resume = False # True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(): 
    trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(512),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # get training dataset
    leafDiseaseCLS = CustomDataSet(images_path, is_to_ls, trans)
    
    data_loader = DataLoader(leafDiseaseCLS,
                             batch_size=16,
                             num_workers=0,
                             shuffle=True,
                             pin_memory=False)
    
    # get model
    model = EfficientNet.from_pretrained("efficientnet-b3")
    
    # extract the parameter of fully connected layer
    fc_features = model._fc.in_features
    # modify the number of classes
    model._fc = nn.Linear(fc_features, 5)
    
    model.to(device)
        
    # optimizer
    optimizer = optim.SGD(model.parameters(), 
                          lr=0.001, 
                          momentum=0.9,
                          weight_decay=5e-4)
    
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[6, 10], gamma=1/3.)
    
    # loss
    #loss_func = nn.CrossEntropyLoss()
    loss_func = FocalCosineLoss()
    
    start_epoch = -1
    
    if resume:
        model, start_epoch, optimizer,scheduler = load_checkpoint_state("../path/to/checkpoint.tar",
                                                                        device, 
                                                                        model,
                                                                        optimizer,
                                                                        scheduler)
    
    model.train()
    
    epochs = 3
    
    for epoch in range(start_epoch + 1, epochs):
        
        running_loss = 0.0
        
        print("Epoch {}/{}".format(epoch, epochs))
        
        for step, train_data in tqdm(enumerate(data_loader)):
            x_train, y_train = train_data
            
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            
            # forward
            prediction = model(x_train)
            
            optimizer.zero_grad()
            
            loss = loss_func(prediction, y_train)
            
            running_loss += loss.item()
            
            # backward
            loss.backward()
            
            optimizer.step()            
            
            
        scheduler.step()
        
        # saving model
        torch.save(model.state_dict(), str(int(running_loss)) + "_" + str(epoch) + ".pth")
        
        save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss)
        
        print("Loss:{}".format(running_loss))

if __name__ == "__main__":
    train()

Loading part of a pre-trained model

Most of the time we need to adapt the model to our own task, so it is hard to guarantee that our model is exactly the same as the published one, yet the parameters of a pre-trained model do help improve training accuracy. To combine the advantages of both, we need to load only part of the pre-trained model.

pretrained_dict = torch.load("model_data/yolo_weights.pth", map_location=device)

model_dict = model.state_dict()
# drop the keys in pretrained_dict that do not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# optionally also require that the tensor shapes match
#pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
# update the existing model_dict
model_dict.update(pretrained_dict)
# load the state_dict we actually need
model.load_state_dict(model_dict)
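
The commented-out line shows a stricter filter that also requires the tensor shapes to match; this is handy when, for example, the final classification layer of your model has a different output size than the pre-trained weights and its parameters must therefore not be copied over.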

Save & load a model for inference

Save/load state_dict (recommended)

Save:

It is recommended to save only the model's state_dict; when saving a model for inference, only the trained model's learned parameters are needed. The file extension is usually .pt or .pth.

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Note: you must call model.eval() to set the dropout and batch normalization layers to evaluation mode before running inference; otherwise the results will be inconsistent.

Save/load the entire model

Save:

torch.save(model, PATH)

Load:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
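
Note that saving the entire model relies on pickle, so the saved file is bound to the exact class definitions and directory structure used at save time and can break after code refactoring; this is why the state_dict approach above is the recommended one.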

Save & load a general checkpoint for testing or resuming training

Save:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()

When saving a checkpoint for evaluation or for resuming training, you must save not only the model's state_dict but also the optimizer's state_dict, because it contains buffers and parameters that are updated as the model trains. Other items worth saving are the epoch at which training stopped, the latest recorded training loss, and any external torch.nn.Embedding layers.
The usual approach is to collect these pieces in a dictionary and serialize it with torch.save(). The PyTorch convention is to use the .tar file extension for such checkpoints. Saved this way, a checkpoint is typically 2-3 times larger than the model file alone.
To load, first initialize the model and optimizer, then use torch.load() to load the dictionary (checkpoint) from disk.
If you are doing inference, you must call model.eval() to set the dropout and batch normalization layers to evaluation mode; failing to do so will produce inconsistent inference results. If you want to resume training, call model.train() to make sure these layers are in training mode.

Warm-starting: loading parameters from a different model

Save:

torch.save(modelA.state_dict(), PATH)

Load:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

When doing transfer learning or training a new, more complex model, partially loading a model (or loading part of a model) is a common approach. Reusing trained parameters, even just a few of them, helps warm-start the training process, and the model can usually be expected to converge faster than when training from scratch.
Whether you are loading from a state_dict that is missing some keys, or from a state_dict that has more keys than the model you are loading into, you only need to set the strict argument to False in load_state_dict() to ignore the non-matching keys.
If you want to load parameters from one layer into another but some keys do not match, simply rename the parameter keys in the state_dict being loaded so that they match the keys of the model you are loading into, as sketched below.
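
A minimal sketch of that key-renaming idea, assuming a saved layer named layer1 whose weights we want to load into a layer named layer_a (both names are made up for illustration; PATH and modelB are the placeholders from the snippet above):

pretrained_dict = torch.load(PATH)

# Rename keys such as "layer1.weight" -> "layer_a.weight" so that they match
# the target model's state_dict; the mapping here is purely illustrative.
renamed_dict = {k.replace("layer1.", "layer_a."): v
                for k, v in pretrained_dict.items()}

modelB.load_state_dict(renamed_dict, strict=False)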

Save/load models across devices (CPU and GPU)

Save the model on the GPU, load it on the CPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

Save the model on the GPU, load it on the GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

Be sure to use .to(torch.device('cuda')) to move all of the model's input data to the GPU. Note that calling my_tensor.to(device) returns a new copy of my_tensor on the GPU; it does not overwrite my_tensor. Therefore, remember to reassign the tensor manually: my_tensor = my_tensor.to(torch.device('cuda')).

Of course, there is also the case of saving (training) on the CPU and loading on the GPU; that case is not demonstrated in the example above, but a brief sketch follows.
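
For completeness, a minimal sketch of that case, following the same pattern as above (PATH and TheModelClass are the placeholders from the earlier snippets):

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# map the CPU-saved tensors onto the GPU while loading
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# then move the model's parameters and buffers to the GPU
model.to(device)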

The content above comes from the official PyTorch documentation.
Have the patience to read the official docs and get past the initial sense of difficulty; you will find that they are actually quite easy to read.


Origin: blog.csdn.net/yanghao201607030101/article/details/110947689