pytorch加载预训练模型参数的方式

1.直接使用默认程序里的下载方式,往往比较慢;

2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下:

通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标左键),查看此网络的加载方法,修改model.load_state_dict()函数。

例如:vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))

猜你喜欢

转载自www.cnblogs.com/ywheunji/p/10607422.html