Torch pitfalls: a short summary

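I recently started using Torch to build neural networks. It was my first time with the framework and I ran into quite a few pitfalls, so I am summarizing them here.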

tensor type

Torch tensors are typed. The default type is DoubleTensor:

th>torch.getdefaulttensortype()
torch.DoubleTensor

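Besides DoubleTensor there are FloatTensor, CudaTensor, IntTensor, CharTensor and others. torch.setdefaulttensortype sets the default tensor type, and when a network is trained with CUDA it is common to set the default to CudaTensor. The pitfall is that CudaTensor lacks some members that the CPU tensor types have (loosely, "member functions"), for example image. Tab completion in the session below shows the difference: a FloatTensor has an image field, while a torch.Tensor created with CudaTensor as the default type does not: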

th> a=torch.FloatTensor():resize(sample_temp:size()):copy(sample_temp)
                                                                      [0.0001s] 
th> a.i
a.image.         a.indexCopy(     a.isContiguous(  a.isSize(
a.index(         a.indexFill(     a.isSameSizeAs(  
a.indexAdd(      a.int(           a.isSetTo(       
th> a=torch.Tensor():resize(sample_temp:size()):copy(sample_temp)
                                                                      [0.0003s] 
th> a.i
a.index(         a.indexFill(     a.isContiguous(  a.isSize(
a.indexAdd(      a.int(           a.isSameSizeAs(  
a.indexCopy(     a.inverse(       a.isSetTo(       
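The key detail is that torch.Tensor() always creates a tensor of the current default type. A minimal sketch that makes this visible without a GPU (it assumes the image package is installed). Setting the default to torch.CudaTensor instead (after require 'cutorch') would, per the observation above, produce a torch.Tensor() without the image field.

```lua
require 'torch'
require 'image'  -- the image package attaches an `image` table to CPU tensor types

-- torch.Tensor() builds a tensor of whatever the *current default* type is.
local prev = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
local a = torch.Tensor()
print(a:type())        -- torch.FloatTensor
print(a.image ~= nil)  -- true, matching the a.image completion above
torch.setdefaulttensortype(prev)  -- restore the previous default
```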

With CudaTensor as the default tensor type, calling image.save fails with the following error:

th> image.save('1.jpg',sample_temp);
.../lijiguo/code/torch/install/share/lua/5.1/image/init.lua:166: attempt to index field 'image' (a nil value)
stack traceback:
    .../lijiguo/code/torch/install/share/lua/5.1/image/init.lua:166: in function 'clampImage'
    .../lijiguo/code/torch/install/share/lua/5.1/image/init.lua:267: in function 'saver'
    .../lijiguo/code/torch/install/share/lua/5.1/image/init.lua:457: in function 'save'
    [string "image.save('1.jpg',sample_temp);"]:1: in main chunk
    [C]: in function 'xpcall'
    .../lijiguo/code/torch/install/share/lua/5.1/trepl/init.lua:679: in function 'repl'
    ...code/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:204: in main chunk
    [C]: at 0x00406620  
                                                                      [0.0004s] 
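Judging from the traceback and the reproduction above, clampImage appears to build its scratch tensor with torch.Tensor(), that is, with the default type, so converting the argument alone may not be enough. A minimal workaround sketch, assuming that behaviour: temporarily switch the default type to a CPU type around the call and restore it afterwards. saveImageSafely is a hypothetical helper name, not part of the image package.

```lua
require 'torch'
require 'image'

-- Hypothetical helper: save an image even when the default tensor type is
-- torch.CudaTensor. Assumes image.save allocates with torch.Tensor(), so the
-- default is switched to FloatTensor just for the duration of the call.
local function saveImageSafely(path, img)
  local prev = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')
  -- also copy the data to the CPU, in case img is a CudaTensor
  local ok, err = pcall(image.save, path, img:float())
  torch.setdefaulttensortype(prev)  -- always restore, even on failure
  if not ok then error(err, 2) end
end
```

Usage: saveImageSafely('1.jpg', sample_temp). A simpler design is to leave the default type on the CPU and move only the model and its inputs to the GPU explicitly with :cuda().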

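Also, the input to a network must have the same tensor type as the network. If the network is on the GPU, its tensors are CudaTensors, so the input vector must be converted to a CudaTensor as well; otherwise you get the error below (others have reported the same problem on GitHub: https://github.com/torch/torch7/issues/1087):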

th> input_vector=torch.FloatTensor(1,100);
                                                                      [0.0000s] 
th> sample=model.G:forward(input_vector)
...ijiguo/code/torch/install/share/lua/5.1/nn/Container.lua:67: 
In 1 module of nn.Sequential:
...1/lijiguo/code/torch/install/share/lua/5.1/nn/Linear.lua:66: invalid arguments: CudaTensor number CudaTensor number FloatTensor CudaTensor 
expected arguments: *CudaTensor~2D* [CudaTensor~2D] [float] CudaTensor~2D CudaTensor~2D | *CudaTensor~2D* float [CudaTensor~2D] float CudaTensor~2D CudaTensor~2D
stack traceback:
    [C]: in function 'addmm'
    ...1/lijiguo/code/torch/install/share/lua/5.1/nn/Linear.lua:66: in function <...1/lijiguo/code/torch/install/share/lua/5.1/nn/Linear.lua:53>
    [C]: in function 'xpcall'
    ...ijiguo/code/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    ...jiguo/code/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
    [string "sample=model.G:forward(input_vector)"]:1: in main chunk
    [C]: in function 'xpcall'
    .../lijiguo/code/torch/install/share/lua/5.1/trepl/init.lua:679: in function 'repl'
    ...code/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:204: in main chunk
    [C]: at 0x00406620

WARNING: If you see a stack trace below, it doesn't point to the place where this error occurred. Please use only the one above.
stack traceback:
    [C]: in function 'error'
    ...ijiguo/code/torch/install/share/lua/5.1/nn/Container.lua:67: in function 'rethrowErrors'
    ...jiguo/code/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
    [string "sample=model.G:forward(input_vector)"]:1: in main chunk
    [C]: in function 'xpcall'
    .../lijiguo/code/torch/install/share/lua/5.1/trepl/init.lua:679: in function 'repl'
    ...code/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:204: in main chunk
    [C]: at 0x00406620  
                                                                      [0.0009s] 
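The direct fix is to convert the input, e.g. input_vector = input_vector:cuda() with cunn loaded. A minimal type-agnostic sketch, assuming the model has at least one parameter tensor: convert the input to the type of the model's first parameter with typeAs. matchModelType is a hypothetical helper, and the demo runs on the CPU (a FloatTensor model fed a DoubleTensor) so it works without a GPU; after model:cuda() the same call would yield a CudaTensor.

```lua
require 'nn'
-- require 'cunn'  -- needed for the GPU case (model:cuda())

-- Hypothetical helper: convert `input` to the tensor type of the
-- network's first parameter tensor (e.g. CudaTensor after model:cuda()).
local function matchModelType(model, input)
  local params = model:parameters()
  assert(params and #params > 0, 'model has no parameters to take a type from')
  return input:typeAs(params[1])
end

-- CPU demo: a FloatTensor network fed a DoubleTensor input.
local model = nn.Sequential():add(nn.Linear(100, 10)):float()
local input = torch.DoubleTensor(1, 100):uniform()
local out = model:forward(matchModelType(model, input))
print(out:type())  -- torch.FloatTensor
```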

lua table index

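Lua table indices start at 1. Never forget this!!!
Tensor indexing follows the same rule, so sample[0] on the network output gives the following error: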

th> image.save('1.jpg',sample[0]);
bad argument #2 to '?' (out of range)
stack traceback:
    [C]: at 0x7f1d34138bd0
    [C]: in function '__index'
    [string "image.save('1.jpg',sample[0]);"]:1: in main chunk
    [C]: in function 'xpcall'
    .../lijiguo/code/torch/install/share/lua/5.1/trepl/init.lua:679: in function 'repl'
    ...code/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:204: in main chunk
    [C]: at 0x00406620  
                                                                      [0.0001s] 
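A plain-Lua sketch of the 1-based convention (no torch needed). For a tensor such as sample above, the first slice is likewise sample[1]; the table contents here are made up for illustration.

```lua
-- Lua sequences are 1-based: index 0 is just an absent key.
local samples = {"img_a", "img_b", "img_c"}

print(samples[0])  -- nil: nothing stored at index 0
print(samples[1])  -- img_a
print(#samples)    -- 3

-- Idiomatic loop over every element: 1 .. #t, both ends inclusive
for i = 1, #samples do
  print(i, samples[i])
end
```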


Reposted from blog.csdn.net/smallflyingpig/article/details/78762645