[ pytorch ] —— Basic usage: (5) Parallel branches in a model

The key point when building parallel branches is this: even if the parallel modules are identical, each one must be created as a separate instance and bound to its own attribute. If a single instance is reused, model.parameters() will only contain that module's parameters once, rather than one set of parameters per branch.
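To see the difference, compare a module that reuses one Conv2d instance with one that creates three separate instances. This is a minimal sketch of the pitfall; the class names SharedWrong and SeparateBranches are illustrative only and do not come from the original post.

import torch
import torch.nn as nn

# Wrong: the same Conv2d instance is bound to three attributes,
# so its parameters are registered (and counted) only once.
class SharedWrong(nn.Module):
    def __init__(self):
        super(SharedWrong, self).__init__()
        conv = nn.Conv2d(3, 10, kernel_size=3)
        self.Block1 = conv
        self.Block2 = conv
        self.Block3 = conv

# Right: three independent instances, three independent parameter sets.
class SeparateBranches(nn.Module):
    def __init__(self):
        super(SeparateBranches, self).__init__()
        self.Block1 = nn.Conv2d(3, 10, kernel_size=3)
        self.Block2 = nn.Conv2d(3, 10, kernel_size=3)
        self.Block3 = nn.Conv2d(3, 10, kernel_size=3)

print(len(list(SharedWrong().parameters())))       # 2  (one weight + one bias)
print(len(list(SeparateBranches().parameters())))  # 6  (three weights + three biases)

With that in mind, the full example builds three parallel blocks and fuses their outputs: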

import torch
import torch.nn as nn

class MyModel(nn.Module):     # e.g. parallel branches on top of a ResNet-50 / encoder-decoder backbone
    def __init__(self, class_num=2):
        super(MyModel, self).__init__()

        ### Build each parallel block as a separate instance ###
        # A plain dict keyed by name is used here; writing into locals()
        # is not reliable inside a function in Python 3.
        # Note: in_channels must match the channel count of the incoming
        # feature map (3 here; the original comments mention a 7*7*2048
        # ResNet-50 feature map, which would need in_channels=2048).
        blocks = {}
        for iii in range(3):
            layers = [
                nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3),
                nn.AdaptiveAvgPool2d((1, 1)),
            ]
            blocks["block" + str(iii + 1)] = nn.Sequential(*layers)   # block1, block2, block3
        ### Build each parallel block as a separate instance ###

        # Bind each block to its own attribute so that all three parameter
        # sets are registered on the module.
        self.Block1 = blocks["block1"]
        self.Block2 = blocks["block2"]
        self.Block3 = blocks["block3"]

        # ClassBlock is a small classifier head defined elsewhere (not shown
        # in this post); its input dim must match the fused feature dim (10).
        self.classifier = ClassBlock(10, class_num=class_num)

    def forward(self, input):

        x = input   # shape (N, C, H, W); C must match in_channels above

        ### Parallel forward ###
        x1 = self.Block1(x)
        x2 = self.Block2(x)
        x3 = self.Block3(x)
        x1 = torch.squeeze(x1)   # (N, 10, 1, 1) -> (N, 10)
        x2 = torch.squeeze(x2)
        x3 = torch.squeeze(x3)
        ### Parallel forward ###

        feature = torch.add(x1, x2)       # feature fusion by element-wise sum
        feature = torch.add(feature, x3)

        feature = self.classifier(feature)

        return feature

################
#    main()
# --------------

net = MyModel()

print(list(net.named_parameters()))  # net.Block1, net.Block2 and net.Block3 all appear in the output
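As an alternative, the same parallel branches can be kept in an nn.ModuleList, which registers every contained module so all three parameter sets appear in named_parameters() without needing separate attribute names. This is only a sketch under the same branch configuration; the class name MyModelList and the plain Linear head (standing in for ClassBlock) are illustrative, not from the original post.

import torch
import torch.nn as nn

class MyModelList(nn.Module):
    def __init__(self, class_num=2):
        super(MyModelList, self).__init__()
        # nn.ModuleList registers each contained branch as a submodule.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3),
                nn.AdaptiveAvgPool2d((1, 1)),
            )
            for _ in range(3)
        ])
        self.classifier = nn.Linear(10, class_num)  # simple head in place of ClassBlock

    def forward(self, x):
        # Run every branch on the same input and fuse by element-wise sum.
        feats = [block(x).flatten(1) for block in self.blocks]
        feature = torch.stack(feats, dim=0).sum(dim=0)
        return self.classifier(feature)

net2 = MyModelList()
print([name for name, _ in net2.named_parameters()])  # blocks.0..., blocks.1..., blocks.2...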

Reposted from blog.csdn.net/jdzwanghao/article/details/84197304