在Pytorch中构建网络时,nn.Sequential容器只能传入继承自nn.Module的模块(参数),搭建一个流水线式的神经网络。
就一个简单的打平flatten操作来说,比如要将一个(N,C,H,W)的输入打平为 (N , C*H*W),其实就是一行简单的 x = x.view(x,size(0), -1)
但是这句话不能单独加入到nn.Sequential容器中。 而nn.Module中也确实没这个打平flatten操作的子模块。
所以就必须自定义一个,便于直接放入nn.Sequential容器中。
代码实现很简单。
class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__()
def forward(self, input):
return input.view(input.size(0), -1)
那么在使用的时候就可以直接作为参数放入nn.Sequential容器中。
net = nn.Sequential(
nn.Conv2d(1,16,stride=1,padding=1),
nn.MaxPool2d(2,2),
Flatten(),# 这里是自己实现的继承自nn.Modeules的子类
nn.Linear(xxx,xx))