pytroch forward() missing 1 required positional argument: 'input'的一个可能原因

版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/83018581

vgg16=torchvision.models.vgg16(pretrained=True)
class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.features=vgg16.features
()
        self.fc=vgg16.classifier()
    def forward(self, x):
        x=self.features(x)
        x=self.fc(x)
        return x

其中在构造函数__init__中,self.features=vgg16.features()相当于调用了vgg16的forward方法来传播数据,但是原本是打算继承VGG模型的,所以只需要把vgg16.features后面的括号去掉就可以了。。

正确的写法:

vgg16=torchvision.models.vgg16(pretrained=True)
class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.features=vgg16.features
        self.fc=vgg16.classifier
    def forward(self, x):
        x=self.features(x)
        x=self.fc(x)
        return x

猜你喜欢

转载自blog.csdn.net/york1996/article/details/83018581