What super(MLP, self).__init__() means

class MLP(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()

While reading a paper's source code, I came across this rather awkward-looking bit of syntax.

The idea is actually simple: first find the parent class of MLP (here, the class nn.Module), treat the MLP instance self as an instance of nn.Module, and then have that nn.Module instance call its own __init__ method.

This initializes the attributes inherited from the parent class, and it does so using the parent class's own initialization method.

In other words, the subclass inherits all of the parent class's attributes and methods, so the inherited attributes are naturally initialized by the parent's method.

Of course, if your initialization logic differs from the parent's, you don't have to use the parent's method; you can write your own initialization instead.
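As a toy illustration of this mechanism (Parent and Child are made-up names, not from the original code): calling super(Child, self).__init__() runs the parent's __init__ on the same object, so the attribute defined by the parent ends up on the child instance.

class Parent:
    def __init__(self):
        self.parent_attr = "set by Parent.__init__"  # attribute owned by the parent class

class Child(Parent):
    def __init__(self, dim):
        super(Child, self).__init__()  # let the parent initialize its own attributes first
        self.dim = dim                 # then initialize the child's own attributes

c = Child(8)
print(c.parent_attr)  # set by Parent.__init__
print(c.dim)          # 8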

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        '''
        num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        '''

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        if self.linear_or_not:
            #If linear model
            return self.linear(x)
        else:
            #If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[layer](self.linears[layer](h)))
            return self.linears[self.num_layers - 1](h)
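As a quick usage sketch (the dimensions below are arbitrary, chosen only for illustration), the module can be instantiated and applied to a random batch:

# hypothetical dimensions, for illustration only
mlp = MLP(num_layers=3, input_dim=16, hidden_dim=32, output_dim=4)
x = torch.randn(8, 16)   # a batch of 8 samples with 16 features each
out = mlp(x)             # each hidden layer applies linear -> batch norm -> ReLU
print(out.shape)         # torch.Size([8, 4])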

Reposted from blog.csdn.net/qq_36936730/article/details/113977503