pytorch模型存取与加载

# 存取模型

PATH='./xxx.xx' # 模型名字
torch.save(net.state_dict(),PATH) # 保存,net为需要保存的网络
pretrained_net = torch.load(PATH) # 读取
net2 = Net() # Net()为保存的模型同结构的模型
net2.load_state_dict(pretrained_net) # 加载权重
发布了22 篇原创文章 · 获赞 0 · 访问量 4427

猜你喜欢

转载自blog.csdn.net/Yolo_C/article/details/104103810