Torch nn.Concat,nn.ConcatTable,nn.JoinTable.

1、根据字面意思,nn.Concat,nn.ConcatTable都是将输出concat在一起,那么二者有什么区别呢。concat将最终的concat的结果变成一个整体,ConcatTable将输出保存在一个table里面。

mlp=nn.Concat(1);
mlp:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3))
mlp:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3))
print(mlp:forward(torch.randn(2,3,256,256)))          --输出结果为4x64x128x128
mlp=nn.ConcatTable();
mlp:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3))
mlp:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3))
print(mlp:forward(torch.randn(2,3,256,256)))

第二段代码输出结果为
这里写图片描述
二者的共同点则是接受同一个输入,对多个输出进行操作,所以这两个操作无法进行将多个输入连在一起的操作,如果想要执行这个操作,就用到了nn.JoinTable().

2、JoinTable()将多个输入concat在一起,并且生成一个整体

h1 = nn.SpatialConvolution(3,3,7,7,2,2,3,3)()
h2 = nn.SpatialConvolution(3,64,7,7,2,2,3,3)(h1)
h3 = nn.SpatialConvolution(3,64,7,7,2,2,3,3)(h1)
h4 = nn.JoinTable(2)({h3,h2})
mlp = nn.gModule({h1}, {h4})

x = torch.rand(2,3,256,256)
output = mlp:forward(x)
print(output:size())

输出为2x128x64x64

3、/gmodule.lua:135: expecting only one start
遇见这种错误,查看这句话input和output是否都是一个table

local model = nn.gModule({inp}, {tmpOut})

4、nn.sigmoid()之前是不能加bn和relu的,否则损失极大无法下降

猜你喜欢

转载自blog.csdn.net/u013548568/article/details/79859238
今日推荐