图神经网络(三)GCN的变体与框架(5)GraphSAGE实战

图神经网络(三)GCN的变体与框架(5)GraphSAGE实战

3.5 GraphSAGE实战

 本节我们通过代码来介绍GraphSAGE以加深读者对相关知识的理解。如3.1节所介绍的,GraphSAGE包括两个方面,一是对于邻居的采样;二是对邻居的聚合操作 [0]

 首先来看下对邻居的采样方法,为了实现更高效地采样,可以将节点及其邻居存放在一起,即维护一个节点与其邻居对应关系的表。我们可以通过两个函数samplingmultihop_sampling来实现采样的具体操作。其中sampling是进行一阶采样,根据源节点采样指定数量的邻居节点,multihop_sampling则是利用sampling实现多阶采样的功能。如代码清单3-1所示:

代码清单3-1 对邻居节点进行多阶采样

def sampling(src_nodes, sample_num, neighbor_table):
    """根据源节点采样指定数量的邻居节点,注意使用的是有放回的采样;
    某个节点的邻居节点数量少于采样数量时,采样结果出现重复的节点
    
    Arguments:
        src_nodes {list, ndarray} -- 源节点列表
        sample_num {int} -- 需要采样的节点数
        neighbor_table {dict} -- 节点到其邻居节点的映射表
    
    Returns:
        np.ndarray -- 采样结果构成的列表
    """
    results = []
    for sid in src_nodes:
        # 从节点的邻居中进行有放回地进行采样
        res = np.random.choice(neighbor_table[sid], size=(sample_num, ))
        results.append(res)
    return np.asarray(results).flatten()


def multihop_sampling(src_nodes, sample_nums, neighbor_table):
    """根据源节点进行多阶采样
    
    Arguments:
        src_nodes {list, np.ndarray} -- 源节点id
        sample_nums {list of int} -- 每一阶需要采样的个数
        neighbor_table {dict} -- 节点到其邻居节点的映射
    
    Returns:
        [list of ndarray] -- 每一阶采样的结果
    """
    sampling_result = [src_nodes]
    for k, hopk_num in enumerate(sample_nums):
        hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table)
        sampling_result.append(hopk_result)
    return sampling_result

 这样采样得到的结果仅是节点的ID,还需要根据节点ID去查询每个节点的特征,以进行聚合操作更新特征。

 下面根据如下平均加和(mean/sum)聚合算子公式(公式详见):
Agg sum = σ ( SUM { W h j + b ,   ∀ v j ∈ N ( v i ) } ) \text{Agg}^\text{sum}=σ(\text{SUM}\{W\boldsymbol{h}_j+\boldsymbol{b},\ ∀v_j∈N(v_i)\}) Aggsum=σ(SUM{ Whj+b, vjN(vi)})
 与池化(pooling)聚合算子公式(公式详见):
Agg pool = MAX { σ ( W h j + b ,   ∀ v j ∈ N ( v i ) } \text{Agg}^\text{pool}=\text{MAX}\{σ(W\boldsymbol{h}_j+\boldsymbol{b},\ ∀v_j∈N(v_i)\} Aggpool=MAX{ σ(Whj+b, vjN(vi)}
来实现邻居的聚合操作,计算的过程定义在forward函数中,输入neighbor_feature表示需要聚合的邻居节点的特征,它的维度为 N src × N neighbor × D in N_{\text{src}}×N_{\text{neighbor}}×D_{\text{in}} Nsrc×Nneighbor×Din ,其中 N src N_{\text{src}} Nsrc 表示源节点的数量, N neighbor N_{\text{neighbor}} Nneighbor 表示邻居节点的数量, D in D_{\text{in}} Din 表示输入的特征维度。将这些邻居节点的特征经过一个线性变换得到隐层特征,这样就可以沿着第 1 1 1 个维度进行聚合操作了,包括求和、均值和最大值,得到维度为 N src × D in N_{\text{src}}×D_{\text{in}} Nsrc×Din 的输出。如代码清单3-2所示:

代码清单3-2 邻居聚合

class NeighborAggregator(nn.Module):
    def __init__(self, input_dim, output_dim, 
                 use_bias=False, aggr_method="mean"):
        """聚合节点邻居

        Args:
            input_dim: 输入特征的维度
            output_dim: 输出特征的维度
            use_bias: 是否使用偏置 (default: {False})
            aggr_method: 邻居聚合方式 (default: {mean})
        """
        super(NeighborAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.aggr_method = aggr_method
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, neighbor_feature):
        if self.aggr_method == "mean":
            aggr_neighbor = neighbor_feature.mean(dim=1)
        elif self.aggr_method == "sum":
            aggr_neighbor = neighbor_feature.sum(dim=1)
        elif self.aggr_method == "max":
            aggr_neighbor = neighbor_feature.max(dim=1)
        else:
            raise ValueError("Unknown aggr type, expected sum, max, or mean, but got {}"
                             .format(self.aggr_method))
        
        neighbor_hidden = torch.matmul(aggr_neighbor, self.weight)
        if self.use_bias:
            neighbor_hidden += self.bias

        return neighbor_hidden

 基于邻居聚合的结果对中心节点的特征进行更新。更新的方式是将邻居节点聚合的特征与经过线性变换的中心节点的特征进行求和或者级联,在经过一个激活函数,得到更新后的特征。如代码清单3-3所示:


代码清单3-3 SageGCN定义

class SageGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                 activation=F.relu,
                 aggr_neighbor_method="mean",
                 aggr_hidden_method="sum"):
        """SageGCN层定义

        Args:
            input_dim: 输入特征的维度
            hidden_dim: 隐层特征的维度,
                当aggr_hidden_method=sum, 输出维度为hidden_dim
                当aggr_hidden_method=concat, 输出维度为hidden_dim*2
            activation: 激活函数
            aggr_neighbor_method: 邻居特征聚合方法,["mean", "sum", "max"]
            aggr_hidden_method: 节点特征的更新方法,["sum", "concat"]
        """
        super(SageGCN, self).__init__()
        assert aggr_neighbor_method in ["mean", "sum", "max"]
        assert aggr_hidden_method in ["sum", "concat"]
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.aggr_neighbor_method = aggr_neighbor_method
        self.aggr_hidden_method = aggr_hidden_method
        self.activation = activation
        self.aggregator = NeighborAggregator(input_dim, hidden_dim,
                                             aggr_method=aggr_neighbor_method)
        self.weight = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)

    def forward(self, src_node_features, neighbor_node_features):
        neighbor_hidden = self.aggregator(neighbor_node_features)
        self_hidden = torch.matmul(src_node_features, self.weight)
        
        if self.aggr_hidden_method == "sum":
            hidden = self_hidden + neighbor_hidden
        elif self.aggr_hidden_method == "concat":
            hidden = torch.cat([self_hidden, neighbor_hidden], dim=1)
        else:
            raise ValueError("Expected sum or concat, got {}"
                             .format(self.aggr_hidden))
        if self.activation:
            return self.activation(hidden)
        else:
            return hidden

 基于前面定义的采样和节点特征更新方式,就可以实现3.1.3节介绍的计算节点嵌入的方法。下面定义了一个两层的模型,隐藏层节点数为 64 64 64 ,假设每阶采样节点数都为 10 10 10 ,那么计算中心节点的输出可以通过以下代码实现。其中前向传播时传入的参数 node_feature_list 是一个列表,其中第 0 0 0 个元素表示源节点的特征,其后的元素表示每阶采样得到的节点的特征。如代码清单3-4所示:


代码清单3-4 GraphSage模型示例

class GraphSage(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                 num_neighbors_list):
        super(GraphSage, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_neighbors_list = num_neighbors_list
        self.num_layers = len(num_neighbors_list)
        self.gcn = nn.ModuleList()
        self.gcn.append(SageGCN(input_dim, hidden_dim[0]))
        for index in range(0, len(hidden_dim) - 2):
            self.gcn.append(SageGCN(hidden_dim[index], hidden_dim[index+1]))
        self.gcn.append(SageGCN(hidden_dim[-2], hidden_dim[-1], activation=None))

    def forward(self, node_features_list):
        hidden = node_features_list
        for l in range(self.num_layers):
            next_hidden = []
            gcn = self.gcn[l]
            for hop in range(self.num_layers - l):
                src_node_features = hidden[hop]
                src_node_num = len(src_node_features)
                neighbor_node_features = hidden[hop + 1] \
                    .view((src_node_num, self.num_neighbors_list[hop], -1))
                h = gcn(src_node_features, neighbor_node_features)
                next_hidden.append(h)
            hidden = next_hidden
        return hidden[0]


参考文献:

[0] 刘忠雨, 李彦霖, 周洋.《深入浅出图神经网络: GNN原理解析》.机械工业出版社.

猜你喜欢

转载自blog.csdn.net/weixin_43360025/article/details/124468264