关于Pytorch 0.3 nn.Module的子类,前向传播过程的问题

先上代码:

class ft_net(nn.Module):

    def __init__(self, class_num ):
        super(ft_net, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = model_ft
        self.classifier = ClassBlock(4096, 512)

    def forward(self, x,y):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.squeeze(x)
        print(numpy.shape(x))

当dataloader的batch_size大于1时,执行了打印语句,但是当batch_size等于1时,却没有执行打印语句,再对张量进行观察,当batch_size大于1时,x是[n,1024],但是当batch_size等于1时,却是[1024],
注意!!这里不是[1,1024],这是张量维数的变化,当你代码中使用.cat() .view()等函数时,这将会报错。
之所以出现这个原因,是加入了torch.squeeze(x),squeeze将输入张量形状中的1 去除并返回,所以 一旦batch_size为1,就把4D张量第一维给抹去了

发布了55 篇原创文章 · 获赞 238 · 访问量 21万+

猜你喜欢

转载自blog.csdn.net/zkp_987/article/details/81867158