模型参数加载,权重参数加载

    loop:
        model.train()    # 切换至训练模式
        train……
        model.eval()    # 验证模式
        with torch.no_grad():
            Evaluation
    end loop

加载 模型、参数:

    def load_network(net, load_path, strict=False, param_key='params'):
        """Load shape-compatible weights from a checkpoint file into ``net``.

        Each checkpoint entry is copied only when the model has a parameter
        with the same name and the same tensor size; everything else in the
        model keeps its current value.

        Args:
            net: The model (``torch.nn.Module``) to receive the weights.
            load_path: Path of the saved state_dict / checkpoint file.
            strict: Forwarded to ``net.load_state_dict``.
            param_key: If the loaded checkpoint is a dict that stores the
                weights under this key (e.g. ``{'params': state_dict}``),
                that nested dict is used. A flat state_dict is used as is.
                Pass ``None`` to disable the lookup.

        Returns:
            The same ``net`` object, with the matching weights loaded.
        """
        # Start from the model's own state_dict so unmatched entries stay as they are.
        net_dict = net.state_dict()

        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        # Loading onto the CPU lets a GPU-saved checkpoint open on a CPU-only
        # host; load_state_dict copies values into the model's existing tensors.
        load_net = torch.load(load_path, map_location='cpu')

        # Unwrap checkpoints that nest the weights under `param_key`.
        if (param_key is not None and isinstance(load_net, dict)
                and isinstance(load_net.get(param_key), dict)):
            load_net = load_net[param_key]

        # Take a checkpoint tensor only when name and size both match the model.
        for name, tensor in load_net.items():
            target = net_dict.get(name)
            if target is None or not torch.is_tensor(tensor):
                continue  # unknown key or non-tensor entry: skip instead of raising
            if tensor.size() == target.size():
                net_dict[name] = tensor

        # Push the merged parameters back into the model.
        net.load_state_dict(net_dict, strict=strict)
        return net

#它的逻辑是:逐个比较模型的 state_dict(net_dict)与预训练权重(load_net)中同名参数的 tensor 尺寸(size),尺寸一致则导入,不一致则不导入。

存 参数:

# Parameters of the classification head, as a name -> tensor mapping.
cls_dict = cls_head.state_dict()

# To read the head parameters back from a saved checkpoint:
#   cls_models = os.path.join('cls_head_checkpoint_best.pth')
#   new_state_dict = {k: torch.load(cls_models)['head'][k] for k in cls_dict}

# Bundle the epoch counter, head weights and optimizer state into one checkpoint.
state = {
    'epoch': epoch + 1,
    'head': {k: v for k, v in cls_dict.items()},
    'optimizer': optimizer.state_dict(),
}

torch.save(state, file_path)

模型与模型字典的关系:

#encoding:utf-8
 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#define model
class TheModelClass(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    ``forward`` flattens to 16*5*5 features per sample, which corresponds to
    3-channel 32x32 inputs given the 5x5 convolutions and 2x2 pooling.
    The final layer produces 10 outputs per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (the pooling layer is shared by both stages).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected classifier.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return the raw outputs of ``fc3``."""
        feats = self.pool(F.relu(self.conv1(x)))
        feats = self.pool(F.relu(self.conv2(feats)))
        # Collapse each sample's feature maps into a single vector.
        flat = feats.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
 
def main():
    """Build the demo model and an SGD optimizer, then print both state_dicts."""
    net = TheModelClass()
    sgd = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Model state_dict: one line per entry, showing its name and tensor size.
    print('Model.state_dict:')
    for name, tensor in net.state_dict().items():
        print(name, '\t', tensor.size())

    # Optimizer state_dict: one line per top-level key with its value.
    print('Optimizer,s state_dict:')
    for key, value in sgd.state_dict().items():
        print(key, '\t', value)
 
 
 
# Run the demo only when this file is executed as a script, not when imported.
if __name__=='__main__':
    main()
 
 
Output:
-----------------------------------------------------------------------------------------
Model.state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer,s state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]
-----------------------------------------------------------------------------------------

猜你喜欢

转载自blog.csdn.net/weixin_44040169/article/details/130755030
今日推荐