Graph Neural Networks (4) Graph Classification (3): Graph Classification in Practice

4.3 Graph Classification in Practice

 In this section we implement a pooling mechanism based on self-attention (Self-Attention Pooling) in code. The idea of this method is to adaptively learn the importance of each node from the graph itself via graph convolution.[0] Concretely, using the graph convolution defined in Chapter 1, an importance score can be assigned to every node, as shown below:
$$Z=\sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}X\Theta_{att}\right)$$

where $\sigma$ denotes an activation function, $\tilde{A}$ is the adjacency matrix with self-loops added, $\tilde{D}$ is its degree matrix, $X$ is the node feature matrix, and $\Theta_{att}\in\mathbb{R}^{F\times 1}$ (with $F$ the node feature dimension) is the only weight parameter introduced by the self-attention pooling layer. For the concrete implementation of this graph convolution, refer to Section 1.6; it is not repeated here.
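To make the score computation concrete, here is a minimal NumPy sketch (independent of the book's PyTorch implementation; the 4-node path graph, the features, and the weight `theta_att` below are made up purely for illustration):

```python
import numpy as np

# Toy 4-node undirected path graph 0-1-2-3 (illustrative only)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=np.float64)
X = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [0.5, 0.5]])             # node features, F = 2
theta_att = np.array([[0.3], [-0.2]])  # made-up weight, shape (F, 1)

A_tilde = A + np.eye(4)          # add self-loops
d = A_tilde.sum(axis=1)
D_inv_sqrt = np.diag(d ** -0.5)  # D~^{-1/2}
# Z = sigma(D~^{-1/2} A~ D~^{-1/2} X Theta_att), with tanh as sigma
Z = np.tanh(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ X @ theta_att)

print(Z.shape)  # (4, 1): one importance score per node
```

Each node ends up with a single scalar score in (-1, 1), which is exactly the quantity the pooling step below ranks.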

 Based on the node importance scores and the graph topology, the pooling operation can then be carried out, as shown below:

$$\boldsymbol{i}=\text{top-rank}(\boldsymbol{z},\,kN)$$

The less important nodes are discarded, and the adjacency matrix and node features are updated accordingly to obtain the pooled result. Let us first look at how to implement the node selection defined by the equation above. The code is shown in Listing 4-1:


Listing 4-1 Pooling based on node importance scores

def top_rank(attention_score, graph_indicator, keep_ratio):
    """Pool each graph based on the given attention_score.

    To make the pooling process easy to follow, each graph is pooled
    separately and the masks are then concatenated for the next step.

    Arguments:
    ----------
        attention_score: torch.Tensor
            Attention scores computed by a GCN, Z = GCN(A, X)
        graph_indicator: torch.Tensor
            Indicates which graph each node belongs to
        keep_ratio: float
            Fraction of nodes to keep; the number kept is int(N * keep_ratio)
    """
    # Sort the graph ids so graphs are processed in ascending order
    graph_id_list = sorted(set(graph_indicator.cpu().numpy()))
    mask = attention_score.new_empty((0,), dtype=torch.bool)
    for graph_id in graph_id_list:
        graph_attn_score = attention_score[graph_indicator == graph_id]
        graph_node_num = len(graph_attn_score)
        graph_mask = attention_score.new_zeros((graph_node_num,),
                                               dtype=torch.bool)
        keep_graph_node_num = int(keep_ratio * graph_node_num)
        _, sorted_index = graph_attn_score.sort(descending=True)
        graph_mask[sorted_index[:keep_graph_node_num]] = True
        mask = torch.cat((mask, graph_mask))
    
    return mask

 The function top_rank takes three arguments. The first is attention_score, the node importance scores produced by the GCN. The second is graph_indicator, which indicates which graph each node belongs to; here we batch the graphs to be classified together to speed up computation, so graph_indicator contains data of the form [0, 0, …, 0, 1, 1, …, 1, 2, 2, …, 2, …]. Note that the identifiers in graph_indicator must be in ascending order, and the nodes belonging to the same graph must be stored contiguously. The third is the hyperparameter keep_ratio, the fraction of nodes to keep at each pooling step; it applies to each individual graph, not to the whole batch. The implementation iterates over the graphs according to graph_indicator, extracts the attention scores of each graph, sorts them to obtain the indices of the nodes to keep, and sets those positions to True, yielding a mask vector for each graph. Concatenating the masks of all graphs gives the mask over every node in the batch, which is the function's return value.
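The per-graph selection logic described above can be mimicked in plain NumPy (a standalone sketch with made-up scores for a batch of two 3-node graphs; `top_rank_np` is a hypothetical helper, not the book's top_rank):

```python
import numpy as np

def top_rank_np(scores, graph_indicator, keep_ratio):
    """Keep the int(N * keep_ratio) highest-scoring nodes of each graph."""
    mask = np.zeros(len(scores), dtype=bool)
    for gid in sorted(set(graph_indicator)):
        idx = np.flatnonzero(graph_indicator == gid)  # nodes of this graph
        keep = int(len(idx) * keep_ratio)
        order = idx[np.argsort(-scores[idx])]         # sort descending
        mask[order[:keep]] = True                     # mark the top-k nodes
    return mask

scores = np.array([0.9, 0.1, 0.5, 0.8, 0.2, 0.7])
graph_indicator = np.array([0, 0, 0, 1, 1, 1])  # two graphs, 3 nodes each
mask = top_rank_np(scores, graph_indicator, keep_ratio=0.5)
print(mask)  # [ True False False  True False False]
```

With keep_ratio = 0.5, int(3 * 0.5) = 1 node survives per graph: the node scoring 0.9 in graph 0 and the node scoring 0.8 in graph 1.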

 Next, the graph structure and the node features are updated according to the obtained node mask. The structure update indexes the adjacency matrix with the mask vector to get the adjacency among the kept nodes, which is then normalized and used as the input of the following GCN layers. We therefore define two helper functions, normalization(adjacency) and filter_adjacency(adjacency, mask). normalization(adjacency) takes a scipy.sparse.csr_matrix, normalizes it, and converts it to a torch.sparse.FloatTensor. filter_adjacency(adjacency, mask) takes two arguments: the pre-pooling adjacency matrix adjacency, of type torch.sparse.FloatTensor, and the node mask mask output by top_rank. To take advantage of the index slicing provided by scipy.sparse, the pre-pooling adjacency is converted to a scipy.sparse.csr_matrix and sliced with mask, giving the adjacency relations among the pooled nodes; the result is then normalized with normalization and used as the input of the next graph convolution layer. See Listing 4-2:


Listing 4-2 Updating the graph structure

def normalization(adjacency):
    """Compute L = D^-0.5 * (A+I) * D^-0.5.

    Args:
        adjacency: sp.csr_matrix.

    Returns:
        The normalized adjacency matrix as a torch.sparse.FloatTensor.
    """
    adjacency += sp.eye(adjacency.shape[0])    # add self-loops
    degree = np.array(adjacency.sum(1))
    d_hat = sp.diags(np.power(degree, -0.5).flatten())
    L = d_hat.dot(adjacency).dot(d_hat).tocoo()
    # convert to torch.sparse.FloatTensor
    indices = torch.from_numpy(np.asarray([L.row, L.col])).long()
    values = torch.from_numpy(L.data.astype(np.float32))
    tensor_adjacency = torch.sparse.FloatTensor(indices, values, L.shape)
    return tensor_adjacency

def filter_adjacency(adjacency, mask):
    """Update the graph structure according to the node mask.
    
    Args:
        adjacency: torch.sparse.FloatTensor, adjacency matrix before pooling
        mask: torch.Tensor(dtype=torch.bool), node mask vector
    
    Returns:
        torch.sparse.FloatTensor, normalized adjacency matrix after pooling
    """
    device = adjacency.device
    mask = mask.cpu().numpy()
    indices = adjacency.coalesce().indices().cpu().numpy()
    num_nodes = adjacency.size(0)
    row, col = indices
    maskout_self_loop = row != col
    row = row[maskout_self_loop]
    col = col[maskout_self_loop]
    sparse_adjacency = sp.csr_matrix((np.ones(len(row)), (row, col)),
                                     shape=(num_nodes, num_nodes), dtype=np.float32)
    filtered_adjacency = sparse_adjacency[mask, :][:, mask]
    return normalization(filtered_adjacency).to(device)
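The effect of the masked slicing inside filter_adjacency can be checked directly on a tiny scipy matrix (a standalone sketch; the 4-node path graph and the mask below are made up for the example):

```python
import numpy as np
import scipy.sparse as sp

# Path graph 0-1-2-3 stored as a CSR matrix (both edge directions)
row = np.array([0, 1, 1, 2, 2, 3])
col = np.array([1, 0, 2, 1, 3, 2])
A = sp.csr_matrix((np.ones(len(row)), (row, col)), shape=(4, 4))

mask = np.array([True, True, False, True])  # drop node 2
A_pooled = A[mask, :][:, mask]              # slice rows, then columns

print(A_pooled.shape)      # (3, 3): only nodes 0, 1, 3 remain
print(A_pooled.toarray())  # edges 1-2 and 2-3 vanish together with node 2
```

Note that the surviving nodes are renumbered 0..2 by the slicing, and any edge incident to a dropped node disappears; only the 0-1 edge remains here.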

 Using the helper functions introduced above, we can now implement the self-attention pooling layer. Its outputs are the pooled node features, the indicator telling which graph each remaining node belongs to, and the normalized adjacency matrix. See Listing 4-3:


Listing 4-3 Pooling layer based on the self-attention mechanism

class SelfAttentionPooling(nn.Module):
    def __init__(self, input_dim, keep_ratio, activation=torch.tanh):
        super(SelfAttentionPooling, self).__init__()
        self.input_dim = input_dim
        self.keep_ratio = keep_ratio
        self.activation = activation
        self.attn_gcn = GraphConvolution(input_dim, 1)
    
    def forward(self, adjacency, input_feature, graph_indicator):
        attn_score = self.attn_gcn(adjacency, input_feature).squeeze()
        attn_score = self.activation(attn_score)
        
        mask = top_rank(attn_score, graph_indicator, self.keep_ratio)
        hidden = input_feature[mask] * attn_score[mask].view(-1, 1)
        mask_graph_indicator = graph_indicator[mask]
        mask_adjacency = filter_adjacency(adjacency, mask)
        return hidden, mask_graph_indicator, mask_adjacency

 For graph classification we also need a global pooling (readout) operation, which maps graphs with different numbers of nodes to representations of the same dimension. Common global pooling choices are taking the maximum or the mean. The implementations of these two variants are shown in Listing 4-4:


Listing 4-4 Graph readout mechanism

import torch_scatter
def global_max_pool(x, graph_indicator):
    num = graph_indicator.max().item() + 1
    return torch_scatter.scatter_max(x, graph_indicator, dim=0, dim_size=num)[0]


def global_avg_pool(x, graph_indicator):
    num = graph_indicator.max().item() + 1
    return torch_scatter.scatter_mean(x, graph_indicator, dim=0, dim_size=num)

 Here we use the torch_scatter package to simplify the implementation; the principle behind the two functions used, scatter_mean and scatter_max, is illustrated in Figure 4-6.
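The semantics of these two torch_scatter functions can be reproduced in plain NumPy (a sketch only; `scatter_max_np` and `scatter_mean_np` are hypothetical helpers, and the segment index plays the role of graph_indicator):

```python
import numpy as np

def scatter_max_np(x, index, dim_size):
    """Element-wise max over the rows of each segment id."""
    out = np.full((dim_size, x.shape[1]), -np.inf)
    for i, seg in enumerate(index):
        out[seg] = np.maximum(out[seg], x[i])
    return out

def scatter_mean_np(x, index, dim_size):
    """Average the rows of each segment id."""
    out = np.zeros((dim_size, x.shape[1]))
    counts = np.zeros(dim_size)
    for i, seg in enumerate(index):
        out[seg] += x[i]
        counts[seg] += 1
    return out / counts[:, None]

x = np.array([[1., 2.], [3., 4.], [5., 6.]])
index = np.array([0, 0, 1])  # rows 0-1 belong to graph 0, row 2 to graph 1
print(scatter_max_np(x, index, 2))   # [[3. 4.] [5. 6.]]
print(scatter_mean_np(x, index, 2))  # [[2. 3.] [5. 6.]]
```

Each graph in the batch is thus reduced to a single fixed-size row, regardless of how many nodes it contains.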


Figure 4-6 Illustration of scatter_mean and scatter_max

 With this in place, we can define the graph classification models. We build the two SAGPool variants shown in Figure 4-7: architecture (a) uses a single pooling layer and is called SAGPool_g, where "g" stands for global (implemented in Listing 4-5); architecture (b) uses multiple pooling layers and is called SAGPool_h, where "h" stands for hierarchical (implemented in Listing 4-6). The experiments in the paper suggest that SAGPool_g is better suited to classifying small graphs, while SAGPool_h works better on large graphs.


Figure 4-7 Graph classification models


Listing 4-5 Implementation of the SAGPool_g model

class ModelA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """Graph classification model, architecture A (SAGPool_g)
        
        Args:
        ----
            input_dim: int, dimension of the input features
            hidden_dim: int, number of hidden units
            num_classes: int, number of classes (default: 2)
        """
        super(ModelA, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool = SelfAttentionPooling(hidden_dim * 3, 0.5)
        self.fc1 = nn.Linear(hidden_dim * 3 * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, num_classes)

    def forward(self, adjacency, input_feature, graph_indicator):
        gcn1 = F.relu(self.gcn1(adjacency, input_feature))
        gcn2 = F.relu(self.gcn2(adjacency, gcn1))
        gcn3 = F.relu(self.gcn3(adjacency, gcn2))
        
        gcn_feature = torch.cat((gcn1, gcn2, gcn3), dim=1)
        pool, pool_graph_indicator, pool_adjacency = self.pool(adjacency, gcn_feature,
                                                               graph_indicator)
        
        readout = torch.cat((global_avg_pool(pool, pool_graph_indicator),
                             global_max_pool(pool, pool_graph_indicator)), dim=1)
        
        fc1 = F.relu(self.fc1(readout))
        fc2 = F.relu(self.fc2(fc1))
        logits = self.fc3(fc2)
        
        return logits

 The SAGPool_h model is implemented as shown in Listing 4-6.


Listing 4-6 Implementation of the SAGPool_h model

class ModelB(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """Graph classification model, architecture B (SAGPool_h)
        
        Args:
        -----
            input_dim: int, dimension of the input features
            hidden_dim: int, number of hidden units
            num_classes: int, number of classes (default: 2)
        """
        super(ModelB, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.pool1 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool2 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool3 = SelfAttentionPooling(hidden_dim, 0.5)
        
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(), 
            nn.Linear(hidden_dim // 2, num_classes))
    
    def forward(self, adjacency, input_feature, graph_indicator):
        gcn1 = F.relu(self.gcn1(adjacency, input_feature))
        pool1, pool1_graph_indicator, pool1_adjacency = \
            self.pool1(adjacency, gcn1, graph_indicator)
        global_pool1 = torch.cat(
            [global_avg_pool(pool1, pool1_graph_indicator),
             global_max_pool(pool1, pool1_graph_indicator)],
            dim=1)
        
        gcn2 = F.relu(self.gcn2(pool1_adjacency, pool1))
        pool2, pool2_graph_indicator, pool2_adjacency = \
            self.pool2(pool1_adjacency, gcn2, pool1_graph_indicator)
        global_pool2 = torch.cat(
            [global_avg_pool(pool2, pool2_graph_indicator),
             global_max_pool(pool2, pool2_graph_indicator)],
            dim=1)

        gcn3 = F.relu(self.gcn3(pool2_adjacency, pool2))
        pool3, pool3_graph_indicator, pool3_adjacency = \
            self.pool3(pool2_adjacency, gcn3, pool2_graph_indicator)
        global_pool3 = torch.cat(
            [global_avg_pool(pool3, pool3_graph_indicator),
             global_max_pool(pool3, pool3_graph_indicator)],
            dim=1)
        
        readout = global_pool1 + global_pool2 + global_pool3
        
        logits = self.mlp(readout)
        return logits

References

[0] 刘忠雨, 李彦霖, 周洋. 《深入浅出图神经网络:GNN原理解析》 (Hands-on Graph Neural Networks: GNN Principles Explained). China Machine Press.

Reprinted from blog.csdn.net/weixin_43360025/article/details/124628559