Task6 Graph Representation Learning Method Based on Graph Neural Network

Introduction

Graph representation learning takes the attributes of a graph's nodes and edges as input and produces a vector that serves as a representation of the whole graph. Based on this graph representation, we can then make predictions about the graph.

Graph Isomorphism Network paper: How Powerful are Graph Neural Networks?

Implementation of a Graph Representation Network Based on the Graph Isomorphism Network (GIN)

Graph representation learning based on the graph isomorphism network mainly consists of the following two steps:

  1. First, compute the node representations;
  2. Second, perform graph pooling (also called graph readout) over the node representations to obtain the graph representation.

Here, we adopt a top-down approach to learn the graph representation learning method based on the graph isomorphism network (GIN). We first focus on how to compute graph representations from node representations, leaving aside for now how the node representations themselves are computed.

Graph representation module based on the graph isomorphism network (GINGraphRepr Module)

The graph representation module:

  • performs node embedding for each node on the graph to obtain node representations;
  • performs graph pooling over the node representations to obtain the graph representation;
  • converts the graph representation into graph-level predictions with one linear layer.
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_regression.gin_node import GINNodeEmbedding
class GINGraphPooling(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """此模块首先采用GINNodeEmbedding模块对图上每一个节点做嵌入,然后对节点嵌入做池化得到图的嵌入,最后用一层线性变换得到图的最终的表示(graph representation)
        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Embed each node on the graph to get node representations
        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, the output is clamped to [0, 50],
            # because the prediction target's values lie in the range (0, 50]
            return torch.clamp(output, min=0, max=50)
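
As a quick sanity check, here is a minimal usage sketch of the module above. It assumes an OGB molecular property prediction dataset (the dataset name "ogbg-molhiv" is only a placeholder) so that the OGB-style atom and bond encoders match the node and edge attributes:

from torch_geometric.data import Batch
from ogb.graphproppred import PygGraphPropPredDataset

# hypothetical sanity check; the dataset is only a stand-in for the molecular data used in this task
dataset = PygGraphPropPredDataset(name="ogbg-molhiv")
batched_data = Batch.from_data_list([dataset[i] for i in range(8)])  # batch of 8 molecular graphs

model = GINGraphPooling(num_tasks=1, num_layers=5, emb_dim=300, graph_pooling="attention")
output = model(batched_data)
print(output.shape)  # torch.Size([8, 1]): one prediction per graph in the batch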

It can be seen that the optional methods for computing graph representations from node representations are (a toy example follows the list):

  • "sum": sum the node representations;
  • "mean": average the node representations;
  • "max": take the element-wise maximum of the node representations, i.e. for each graph, the maximum of each dimension over all of its nodes;
  • "attention": attention-based weighted sum of the node representations;
  • "set2set": another attention-based weighted sum of the node representations, implemented by the module torch_geometric.nn.glob.Set2Set and proposed in the paper "Order Matters: Sequence to sequence for sets".
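
To make these pooling operators concrete, here is a small toy example (values chosen arbitrarily) showing how the batch vector assigns nodes to graphs and what sum, mean and max pooling return:

import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

# toy example: 5 nodes with 3-dimensional representations, belonging to 2 graphs
h_node = torch.tensor([[1., 0., 2.],
                       [0., 1., 1.],
                       [2., 2., 0.],
                       [1., 1., 1.],
                       [3., 0., 1.]])
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1

print(global_add_pool(h_node, batch))   # tensor([[3., 3., 3.], [4., 1., 2.]])
print(global_mean_pool(h_node, batch))  # tensor([[1.0, 1.0, 1.0], [2.0, 0.5, 1.0]])
print(global_max_pool(h_node, batch))   # tensor([[2., 2., 2.], [3., 1., 1.]])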

Next, we will learn about the node embedding method.

Node Embedding Module Based on the Graph Isomorphism Network (GINNodeEmbedding Module)

The node embedding module (GINNodeEmbedding Module):

  • embeds each node with AtomEncoder to obtain the 0-th layer node representations;
  • computes node representations layer by layer;
  • as the number of layers grows, so does the receptive field: the final representation of node i can capture information from all neighboring nodes within a distance of num_layers from node i.
import torch
from gin_regression.mol_encoder import AtomEncoder
from gin_regression.gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # first convert the categorical atom attributes into atom representations
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation
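
To make the two JK ("Jumping Knowledge") branches above concrete, here is a small illustrative sketch (shapes chosen arbitrarily), where h_list plays the same role as in the code above:

import torch

num_layers, num_nodes, emb_dim = 3, 4, 8
h_list = [torch.randn(num_nodes, emb_dim) for _ in range(num_layers + 1)]  # layers 0 .. num_layers

jk_last = h_list[-1]                            # JK="last": keep only the final layer's embeddings
jk_sum = torch.stack(h_list, dim=0).sum(dim=0)  # JK="sum": element-wise sum over all layers
print(jk_last.shape, jk_sum.shape)              # both torch.Size([4, 8])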

Next, let's learn about GINConv, the key component of the graph isomorphism network.

GINConv – Graph isomorphism convolution layer

The graph isomorphism convolution layer is defined mathematically as:

$$\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)$$
This module is already implemented in PyG, and we could directly use the graph isomorphism convolution layer torch_geometric.nn.GINConv defined by PyG. However, that implementation does not support graphs with edge attributes, so here we define our own GINConv module that does support edge attributes.

Since the input edge attributes are categorical, we first need to convert them into edge representations. The GINConv module we define follows the "message passing, message aggregation, message update" process.

  • The process starts with the call to self.propagate, which receives three arguments: edge_index, x, and edge_attr. edge_index is a tensor of shape [2, num_edges].
  • During message passing, x is split along the two rows of edge_index into the tensors x_j and x_i: x_j holds the representations of the source nodes of the messages, and x_i holds the representations of the target nodes.
  • Then the message function is called; it defines the message passed from a source node to a target node, which here is relu applied to the sum of the source node representation and the edge representation. In super(GINConv, self).__init__(aggr="add") we set the aggregation method to "add", so all messages sent to a target node are summed into aggr_out, the intermediate result for that target node.
  • Next, the message update step is executed. Since our GINConv class inherits from MessagePassing, the update function is called. Because we add the target node's own representation later in forward, update simply returns its input aggr_out.
  • Finally, in the forward function we implement the update with out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)).
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)  # first convert the categorical edge attributes into edge representations
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)
        
    def update(self, aggr_out):
        return aggr_out
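
A minimal sanity check of the custom layer on a purely synthetic graph; it assumes OGB-style categorical bond features (three integer features per edge, as BondEncoder expects):

conv = GINConv(emb_dim=64)
x = torch.randn(4, 64)                            # embeddings of 4 nodes
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])         # two undirected edges, stored in both directions
edge_attr = torch.zeros(4, 3, dtype=torch.long)   # categorical bond features in OGB format
out = conv(x, edge_index, edge_attr)
print(out.shape)                                  # torch.Size([4, 64])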

Weisfeiler-Lehman test (WL test)

Graph isomorphism test

  1. Iteratively aggregate the labels of each node and its neighbors;
  2. Hash the aggregated labels into new labels; the update formula is:
     $$L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(u)} L^{h-1}_{v}\right)$$
  3. The WL subtree kernel measures the similarity between graphs: the counts of node labels across the different iterations are used as the representation vector of a graph.
  4. Detailed steps (a small sketch follows the list):
    1. Aggregate the node's own label and the labels of its adjacent nodes into a string;
    2. Hash the label, mapping the longer string to a shorter new label;
    3. Relabel the node.
  5. Graph similarity evaluation:
    1. WL Subtree Kernel method: run the WL Test algorithm to obtain multi-layer labels of the nodes, count the occurrences of each label in the graph, and use the resulting count vector as the representation of the graph;
    2. the inner product of the representation vectors of two graphs is used as an estimate of their similarity.
  6. Necessary condition for judging graph isomorphism: two nodes with the same label and the same neighbor labels are mapped to the same representation; consequently, if the multisets of node labels of two graphs differ, the graphs cannot be isomorphic.
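
As an illustration, here is a minimal sketch of one WL relabeling iteration on a toy graph (the adjacency list and initial labels are made up for the example):

# one iteration of WL label refinement on a toy undirected graph
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}  # adjacency list
labels = {0: "A", 1: "A", 2: "B", 3: "A"}           # initial node labels

new_labels = {}
for u, neighbors in adj.items():
    # aggregate the node's own label with the sorted labels of its neighbors
    aggregated = labels[u] + "".join(sorted(labels[v] for v in neighbors))
    # hash the aggregated string into a shorter new label
    new_labels[u] = str(hash(aggregated) % 1000)
print(new_labels)  # nodes with identical labels and neighbor labels receive the same new label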

Epilogue

In this article, we learned about the graph representation network based on the graph isomorphism network (GIN). To obtain a graph representation, we first compute node representations and then perform a graph readout. The node representation computation in GIN follows the node label update of the WL Test algorithm, so its discriminative power is upper-bounded by the WL Test. In the graph readout, we sum all node representations (weighted, if attention is used), which loses information about the distribution of the nodes.

References

  • Keyulu Xu, Weihua Hu, Jure Leskovec, Stefanie Jegelka. How Powerful are Graph Neural Networks? ICLR 2019.
  • Oriol Vinyals, Samy Bengio, Manjunath Kudlur. Order Matters: Sequence to sequence for sets. ICLR 2016.

Origin blog.csdn.net/weixin_44133327/article/details/118502767