torch container && Table Containers

1
container直接对输入的Tensor进行操作,而Table Containers则是对输入的table进行操作。

2

1) 最常见的container是nn.Sequential,他的目的是将模块串联起来,是串联!

mlp = nn.Sequential()
mlp:add(nn.Linear(10, 25)) -- Linear module (10 inputs, 25 hidden units)
mlp:add(nn.Tanh())         -- apply hyperbolic tangent transfer function on each hidden units
mlp:add(nn.Linear(25, 1))  -- Linear module (25 inputs, 1 output)

2) 第二个常见的container是Parallel(input Dimension,outputDimension),他的意思是将输入沿着input Dimension切开,将他的第i个child应用在切开的第i份数据上,然后最后沿着outputDimension concat在一起

mlp = nn.Parallel(2,1);   -- Parallel container will associate a module to each slice of dimension 2
                           -- (column space), and concatenate the outputs over the 1st dimension.

mlp:add(nn.Linear(10,3)); -- Linear module (input 10, output 3), applied on 1st slice of dimension 2
mlp:add(nn.Linear(10,2))  -- Linear module (input 10, output 2), applied on 2nd slice of dimension 2

                                  -- After going through the Linear module the outputs are
                                  -- concatenated along the unique dimension, to form 1D Tensor
> mlp:forward(torch.randn(10,2)) -- of size 5.
-0.5300
-1.1015
 0.7764
 0.2819
-0.6026

将输入randn(10,2)沿着第二维切开,产生两个randn(10,1),然后将对应的(10,1)应用在对应的子模块上,也即第一个(10,1)应用在nn.Linear(10,3)上面,第二个(10,1)应用在nn.Linear(10,2)上面,最后产生一个大小为(3,1)的结果,一个大小为(2,1)的结果,然后沿着第一维串联在一起
3)concat之前说过了

container的性质:
1)container是从module继承而来,module有的性质它都有
2)get(index),获得container中index处的模块
3)size(),获得container中module的数量

对于nngraph构建的网络的self.model:listModules()和self.model:get(iii)的区别,nngraph是一个container,
self.model:get()是获得nngraph中的Node
self.model:listModules()则是获得整个model的所有元素,对于一个Node可以继续肢解,直到肢解成最小的单位container或者单一的节点,所以总体来说第二种方式会获得更多的肢解信息,因为他的目的是将网络肢解成一个个最小的模块,而get()函数相对于nngraph来讲则只是获得相应的nngraph Node,对于Node它不再继续分解,如果一个nn.Sequential包含了很多内部子模块,但是由于它是一个Node,所以将不会进行肢解。

        threshold_nodes, container_nodes = self.model:findModules('cudnn.SpatialConvolution')
        for i = 1,#threshold_nodes do
          print(threshold_nodes[i])
          print(container_nodes[i]) 
        end

self.model:findModules则会返回查找的module,并且找出对应的container,对于没有container作为父节点的,自己就可以看作是container了,这个函数是查找网络中所有的module,所以必须对网络进行完全的肢解才可以,与self.model:listModules()是对应的,都是container父类module类拥有的性质,它自己继承过来了

3 一般打印网络输出就是用

        for iii = 1,self.model:size(),1 do 
            print(iii,self.model:get(iii))
        end

找到对应模块的标号也即iii,打印相应内容即可

猜你喜欢

转载自blog.csdn.net/u013548568/article/details/79937939
今日推荐