[PyTorch] 基本使用（4）：从模型中取出中间层的输出

######### 模型定义 #########
class MyModel(nn.Module):
    """ResNet-50 backbone followed by a (2048 -> 512) embedding block.

    ``forward`` runs the backbone child module by child module (the "key
    step" of this example) so the output of an intermediate layer --
    ``layer2`` by default -- can be captured along the way. The final pooled
    feature is still fed into ``add_block``.

    Args:
        class_num: Accepted so that callers such as ``MyModel(751)`` keep
            working. It is stored on the instance, but no classifier head is
            built from it.
        mid_layer: Name of the backbone child whose output ``forward`` can
            return alongside the embedding (see ``forward``'s ``return_mid``).
        pretrained: Passed to ``models.resnet50``. ``True`` (the original
            behaviour) loads pretrained weights.

    Raises:
        ValueError: If ``mid_layer`` is not a child module of the backbone
            that runs before the final ``fc`` layer.
    """

    def __init__(self, class_num=None, mid_layer='layer2', pretrained=True):
        super(MyModel, self).__init__()

        BackBone = models.resnet50(pretrained=pretrained)

        # forward() stops before 'fc', so only earlier children can be tapped.
        tappable = [name for name, _ in BackBone.named_children() if name != 'fc']
        if mid_layer not in tappable:
            raise ValueError(
                f"mid_layer={mid_layer!r} is not a usable backbone child; "
                f"expected one of {tappable}")

        # The embedding consumes the pooled feature that the backbone's fc
        # layer would normally receive (2048-d for ResNet-50).
        add_block = nn.Sequential(
            nn.Linear(BackBone.fc.in_features, 512),
            nn.LeakyReLU(inplace=True),
        )
        add_block.apply(weights_init_xavier)

        self.class_num = class_num
        self.mid_layer = mid_layer
        self.BackBone = BackBone
        self.add_block = add_block

    def forward(self, input, return_mid=False):
        """Embed a batch of images and optionally expose a mid-layer output.

        Args:
            input: Image batch for the ResNet-50 backbone, presumably
                (N, 3, H, W) -- TODO confirm against the data pipeline.
            return_mid: If True, also return the output of ``self.mid_layer``.

        Returns:
            The (N, 512) embedding, or ``(embedding, mid_feature)`` when
            ``return_mid`` is True.
        """
        x = input
        mid_feature = None

        ##### Key step: run the backbone one child module at a time #####
        # Relies on the children being registered in execution order
        # (conv1 ... layer4, avgpool, fc), as the original loop did.
        for name, midlayer in self.BackBone.named_children():
            if name == 'fc':
                # Stop before the classifier: add_block expects the pooled
                # feature, not fc's logits.
                break
            x = midlayer(x)
            if name == self.mid_layer:
                mid_feature = x  # e.g. the layer2 feature map
        ##### Key step #####

        # avgpool leaves an (N, C, 1, 1) map; flatten it for the Linear layer.
        x = x.flatten(1)
        x = self.add_block(x)

        if return_mid:
            return x, mid_feature
        return x
##############################


# Debug the model structure: build the net, print it, and push a dummy
# batch through it to check the output shape.
net = MyModel()

print(net)
# torch.randn gives well-defined values, whereas torch.FloatTensor(*sizes)
# allocates uninitialized memory. Variable has been a no-op wrapper since
# PyTorch 0.4, so it is dropped.
dummy_input = torch.randn(8, 3, 256, 128)
print(dummy_input.shape)
with torch.no_grad():  # shape check only; no autograd graph needed
    output = net(dummy_input)
print('net output size:')
print(output.shape)

猜你喜欢

转载自：https://blog.csdn.net/jdzwanghao/article/details/83313057