继承自nn.Module的自定义Flatten模块

在Pytorch中构建网络时,nn.Sequential容器只能传入继承自nn.Module的模块(参数),搭建一个流水线式的神经网络。

就一个简单的打平flatten操作来说,比如要将一个(N,C,H,W)的输入打平为 (N , C*H*W),其实就是一行简单的  x = x.view(x,size(0), -1)

但是这句话不能单独加入到nn.Sequential容器中。 而nn.Module中也确实没这个打平flatten操作的子模块。

所以就必须自定义一个,便于直接放入nn.Sequential容器中。

代码实现很简单。

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten,self).__init__()
    
    def forward(self, input):
        return input.view(input.size(0), -1)

那么在使用的时候就可以直接作为参数放入nn.Sequential容器中。

net = nn.Sequential(
                    nn.Conv2d(1,16,stride=1,padding=1),
                    nn.MaxPool2d(2,2),
                    Flatten(),# 这里是自己实现的继承自nn.Modeules的子类
                    nn.Linear(xxx,xx))

Guess you like

Origin blog.csdn.net/thequitesunshine007/article/details/121228332