一些关于Torch7的记录

版权声明:本文为博主原创文章,如未特别声明,均默认使用CC BY-SA 3.0许可。 https://blog.csdn.net/Geek_of_CSDN/article/details/80834038

有时想要在一台有GPU的机器上(前提是这台机器已经装好了CUDA)训练好模型之后将模型转到CPU型的,这样就可以在没有GPU的机器上(或者没装cuda加速的机器)导入这个模型了。但是可能会遇到奇怪的错误,这里就记录一些贫僧遇到的奇怪的错误。

奇怪错误之读取模型

在进入了torch的交互命令行环境(就是用th来进入的那个环境)之后,如果发现用m = torch.load('x.t7')遇到了这种unknown Torch class <nn.gModule>
stack traceback:
[C]: in function 'error'
错误的话,那么很有可能没有导入全要导入的包,例如这里就少导入了require 'nngraph'包,导入之后就可以读取模型了。

奇怪错误之nil什么什么的

使用m = torch.load('x.t7')来读取模型(t7文件,鬼知道里面存了什么。。。)成功后,使用m = m:float()来转化模型的时候,如果遇到了这个错误:

attempt to call method 'float' (a nil value) 。。。后面略

那么就是时候用torch.type(m)来确定这个模型的类型了的说。通常会发现模型的类型不是名字含有tensor的类型(并且很有可能是nil类型,毕竟错误信息都提示了)。遇到这种情况就要用type(m)来确定这个东东到底是什么东西,例如贫僧在发动了type(m)技能之后发现m居然是个lua的table!

这种情况通常都是因为原模型训练者为了方便把模型(nn.gModule类型)作为table的一部分,和其他一些附加信息(例如模型的设置、作者之类的)一起存在了这个.t7文件里面。

那么怎么提取出真正的模型呢?
首先要做的就是确定这个读取到的table里面到底有哪些键值:

for key, value in pairs(protos) do
    print(key)
end

上面的命令也是直接在交互环境下敲并运行就可以的了。
得到了键值之后可以依次用torch.type(xxx)来确定这个内容是什么类型的,例如有一个键值是doc那么就用torch.type(m.doc),或者torch.type(m["doc"]),两个指令其实是一个意思。

如果全都不是的话可能原训练者很逗逼地把模型藏在了table的table里面。。。这时就在table里面找到的talbe上重复上面的步骤吧。。。

通过这种方式找到真正的模型之后(就是nn.gModule类型的,可能也有别的类型,贫僧还是Torch7小白,不太清楚的说)再用回xxx:float()转换,就可以了。例如m.doc就是那部分模型:

model_to_save = m.doc:float()
torch.save('model.t7', model_to_save)

这样就可以保存好已经转化成cpu型的模型了。

猜你喜欢

转载自blog.csdn.net/Geek_of_CSDN/article/details/80834038