版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/83018581
vgg16=torchvision.models.vgg16(pretrained=True)
class NET(nn.Module):
def __init__(self):
super(NET, self).__init__()
self.features=vgg16.features()
self.fc=vgg16.classifier()
def forward(self, x):
x=self.features(x)
x=self.fc(x)
return x
其中在构造函数__init__中,self.features=vgg16.features()相当于调用了vgg16的forward方法来传播数据,但是原本是打算继承VGG模型的,所以只需要把vgg16.features后面的括号去掉就可以了。。
正确的写法:
vgg16=torchvision.models.vgg16(pretrained=True)
class NET(nn.Module):
def __init__(self):
super(NET, self).__init__()
self.features=vgg16.features
self.fc=vgg16.classifier
def forward(self, x):
x=self.features(x)
x=self.fc(x)
return x