pytorch 模型保存与加载

保存模型或权重参数的后缀问题:

pytorch保存数据的格式为.t7文件或者.pth文件,或者.pkl格式,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

来自:https://blog.csdn.net/weixin_43216883/article/details/89792312

两种方式:

(1)保存模型参数

#保存
torch.save( model.state_dict(), path)
#加载
the_model = CNN()
the_model.load_state_dict(torch.load(path))

这种方法在加载模型的时候,必须在代码中重新将CNN的结构重新建立一遍,才可以将保存好的参数(w和b)放在模型中,进行训练用。

下面是利用上述方法,加载保存好的模型的具体代码,实现随便一张图片的识别(利用的手写体的数据集),CNN结构是用的莫烦里面的结构。

from PIL import Image
import torch.nn as nn
import torch
import numpy as np

def img2vec_img(img):
    #将jpg等格式的图片转为向量
    im = Image.open(img)
    im = im.resize((28,28))
    tmp = np.array(im)
    #vec = tmp.ravel()
    return tmp

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,          
                kernel_size=5,              
                stride=1,                  
                padding=2,                
            ),                             
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(        
            nn.Conv2d(16, 32, 5, 1, 2),   
            nn.ReLU(),                     
            nn.MaxPool2d(2),
            nn.Dropout(0.1)                
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           
        output = self.out(x)
        return output   

#模型加载
the_model = CNN()
the_model.load_state_dict(torch.load('C:\\Users\\happy\\Desktop\\train\\model.pth'))

def detect(path):
    #the_model.eval()
    test_picture = img2vec_img(path)
    data = torch.from_numpy(test_picture).type ( 'torch.FloatTensor' )
    data = torch.unsqueeze(data, 0) #给torch更加一个维度,以便于训练
    data = torch.unsqueeze(data, 0) #给torch更加一个维度,以便于训练
    the_model(data)
    _, pred = torch.max(the_model(data) , 1)
    #print(pred.int())
    print(pred.int())

 
detect('C:\\Users\\happy\\Desktop\\t.bmp')

(2)保存模型整体

#保存
torch.save(model,path)

#加载
the_model = torch.load(path)
发布了56 篇原创文章 · 获赞 29 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/foneone/article/details/91956584