Introduction
Graph representation learning takes the attributes of a graph's nodes and edges as input and produces a vector that represents the whole graph. Based on this graph representation, we can then make predictions about the graph.
Graph Isomorphism Network (GIN) paper: "How Powerful are Graph Neural Networks?"
Implementation of a Graph Representation Network Based on the Graph Isomorphism Network (GIN)
Graph representation learning based on the graph isomorphism network consists of two main steps:
- First, compute the node representations;
- Second, apply graph pooling (also called graph readout) to the representations of all nodes in the graph to obtain a graph representation.
Here we take a top-down approach to learning the GIN-based graph representation method. We first focus on how to compute graph representations from node representations, setting aside for now how the node representations themselves are computed.
Graph Representation Module Based on the Graph Isomorphism Network (GINGraphRepr Module)
The graph representation module (the GINGraphPooling class in the code that follows) works in three steps:
- Node embedding is applied to every node of the graph to obtain node representations.
- Graph pooling is applied to the node representations to obtain a graph representation.
- A single linear layer converts the graph representation into the graph prediction.
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_regression.gin_node import GINNodeEmbedding

#%%
class GINGraphPooling(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """Embed every node of the graph with GINNodeEmbedding, pool the node
        embeddings into a graph embedding, then apply one linear layer to obtain
        the final graph representation.

        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (this controls the dimension of the graph representation).
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): either "last" or "sum". "last" uses only the node embeddings of the last layer; "sum" sums the node embeddings of all layers. Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. One of "sum", "mean", "max", "attention" and "set2set". Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Compute a node embedding for every node of the graph
        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set outputs a vector of size 2 * emb_dim, so the linear layer is widened accordingly
        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)
        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, the output is clamped to [0, 50],
            # because the prediction target takes values in (0, 50]
            return torch.clamp(output, min=0, max=50)
As the code shows, the available methods for computing a graph representation from node representations are:
- "sum": sums the node representations;
- "mean": averages the node representations;
- "max": takes the maximum of the node representations, i.e. the maximum of each dimension over all nodes of each graph in the batch;
- "attention": an attention-based weighted sum of the node representations;
- "set2set":
  - Another attention-based method for weighting and summing the node representations;
  - Uses the module torch_geometric.nn.glob.Set2Set;
  - From the paper "Order Matters: Sequence to sequence for sets".
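To make the "sum", "mean" and "max" options concrete, here is a minimal dependency-free sketch in plain Python. It is not how PyG implements its pooling functions; it only mirrors the idea that a batch vector assigns every node to a graph and that pooling is done per graph, dimension by dimension. The function name readout and the toy data are hypothetical.

```python
from collections import defaultdict

def readout(node_feats, batch, mode="sum"):
    """Pool node feature vectors into one vector per graph.

    node_feats: list of equal-length lists, one per node
    batch: batch[i] is the id of the graph that node i belongs to
    """
    groups = defaultdict(list)
    for feat, graph_id in zip(node_feats, batch):
        groups[graph_id].append(feat)

    pooled = {}
    for graph_id, feats in groups.items():
        columns = list(zip(*feats))  # one tuple per feature dimension
        if mode == "sum":
            pooled[graph_id] = [sum(c) for c in columns]
        elif mode == "mean":
            pooled[graph_id] = [sum(c) / len(c) for c in columns]
        elif mode == "max":
            pooled[graph_id] = [max(c) for c in columns]
        else:
            raise ValueError("Invalid graph pooling type.")
    return pooled

# Three nodes with 2-dimensional features; nodes 0 and 1 form graph 0, node 2 forms graph 1
node_feats = [[1.0, 2.0], [3.0, 0.0], [5.0, 4.0]]
batch = [0, 0, 1]

sum_out = readout(node_feats, batch, "sum")    # graph 0 -> [4.0, 2.0]
mean_out = readout(node_feats, batch, "mean")  # graph 0 -> [2.0, 1.0]
max_out = readout(node_feats, batch, "max")    # graph 0 -> [3.0, 2.0]
```

Whatever the number of nodes in a graph, each mode returns one vector of the node embedding size, which is why the linear prediction layer can have a fixed input size.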
Next, we look at how the node embeddings are computed.
Node Embedding Module Based on the Graph Isomorphism Network (GINNodeEmbedding Module)
The node embedding module (GINNodeEmbedding):
- Uses AtomEncoder to embed the atoms, giving the layer-0 node representations.
- Computes the node representations layer by layer.
- Each additional layer enlarges the receptive field: after num_layers layers, the representation of node i captures information from nodes up to num_layers hops away from node i.
import torch
from gin_regression.mol_encoder import AtomEncoder
from gin_regression.gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""
        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # first convert the categorical atom attributes into atom representations
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                # out-of-place add, so tensors saved for autograd are not modified
                h = h + h_list[layer]

            h_list.append(h)

        # Combine the per-layer representations (jumping knowledge)
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]
        else:
            raise ValueError("Invalid JK type.")

        return node_representation
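The claim that num_layers layers let node i see nodes up to num_layers hops away can be checked with a small sketch. It uses plain Python and only tracks which nodes can influence a given node after each round of neighbor aggregation; it does not compute any embeddings. The function receptive_field and the example path graph are hypothetical.

```python
def receptive_field(adj, node, num_layers):
    """Return the set of nodes whose input features can influence `node`
    after `num_layers` rounds of neighbor aggregation."""
    reached = {node}
    for _ in range(num_layers):
        # every layer mixes in the neighbors of everything reached so far
        reached = reached | {nbr for u in reached for nbr in adj[u]}
    return reached

# Path graph 0 - 1 - 2 - 3 - 4, stored as an adjacency list
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

fields = {k: sorted(receptive_field(adj, 0, k)) for k in range(1, 5)}
# fields[1] == [0, 1], fields[2] == [0, 1, 2], ..., fields[4] == [0, 1, 2, 3, 4]
```

On this path graph, node 0 needs four layers before node 4 can affect its representation, one extra hop per layer.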
Next, we look at GINConv, the key component of the graph isomorphism network.
GINConv: the Graph Isomorphism Convolutional Layer
The layer is defined mathematically as:
$$\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)$$
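One way to read the (1 + ε) factor is that it lets the layer tell a node's own feature apart from its neighbors' features. The sketch below is an illustration under simplifying assumptions: features are scalars, and only the argument of h_Θ is computed, not the MLP itself. The function name gin_pre_mlp is hypothetical.

```python
def gin_pre_mlp(x_i, neighbor_feats, eps):
    """Argument of h_Theta in the GIN update: (1 + eps) * x_i + sum of neighbor features."""
    return (1 + eps) * x_i + sum(neighbor_feats)

# Node a has feature 1 and a single neighbor with feature 2;
# node b has feature 2 and a single neighbor with feature 1.
a_eps0 = gin_pre_mlp(1.0, [2.0], eps=0.0)
b_eps0 = gin_pre_mlp(2.0, [1.0], eps=0.0)

a_eps = gin_pre_mlp(1.0, [2.0], eps=0.5)
b_eps = gin_pre_mlp(2.0, [1.0], eps=0.5)
```

With eps = 0 both nodes give 3.0, so any h_Θ maps them to the same output. With eps = 0.5 they give 3.5 and 4.0, so the two situations are no longer confused.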
This module has already been implemented in PyG, and we can use the graph isomorphism convolution layer defined by PyG via torch_geometric.nn.GINConv. However, that implementation does not support graphs with edge attributes. Here we define a custom GINConv module that does support edge attributes.
Since the input edge attributes are categorical, we first need to convert them into edge representations. The GINConv module we define follows the process of "message passing, message aggregation, message update".
- The process begins with a call to self.propagate, which receives the three arguments edge_index, x and edge_attr. edge_index is a tensor of shape [2, num_edges].
- During message passing, the two rows of this tensor are used to index x, producing the tensors x_j and x_i: x_j holds the source nodes of the message transfer, and x_i holds the target nodes.
- Then the message function is called. It defines the message passed from the source node to the target node; here the message is the sum of the source node representation and the edge representation, passed through relu.
- In super(GINConv, self).__init__(aggr = "add") we set the message aggregation method to add, so all messages passed to a given target node are summed to give aggr_out, the intermediate result for that target node.
- Then the message update step is executed. Our GINConv class inherits from the MessagePassing class, so the update function is called. However, we want to add the target node's own information during the update, which happens in forward, so in the update function we simply return the input aggr_out.
- Finally, in the forward function we run out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)), which implements the message update.
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)  # first convert the categorical edge attributes into edge representations
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        # message from source node j: relu(source representation + edge representation)
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # the node's own term is added in forward(), so the aggregate is returned unchanged
        return aggr_out
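To make the message, aggregation and update steps concrete, the following is a minimal dependency-free sketch of one such step. It rests on several assumptions made only for illustration: node features are scalars, the edge embeddings are given directly as floats (standing in for the BondEncoder output), and self.mlp is replaced by the identity function. The name gin_conv_step and its arguments are hypothetical.

```python
def relu(v):
    return max(v, 0.0)

def gin_conv_step(x, edge_index, edge_emb, eps=0.0, mlp=lambda v: v):
    """One edge-aware GIN step on scalar node features.

    x: list of node features (floats)
    edge_index: [sources, targets], two lists of equal length
    edge_emb: one float per edge, a stand-in for the encoded edge attribute
    mlp: stand-in for self.mlp (identity by default, an assumption)
    """
    sources, targets = edge_index
    aggr_out = [0.0] * len(x)
    for src, dst, e in zip(sources, targets, edge_emb):
        msg = relu(x[src] + e)   # message: relu(x_j + edge_attr)
        aggr_out[dst] += msg     # aggregation with aggr="add"
    # update returns aggr_out unchanged; the node's own term is added here, as in forward
    return [mlp((1 + eps) * x_i + a_i) for x_i, a_i in zip(x, aggr_out)]

# Undirected path 0 - 1 - 2, stored as two directed edges per bond
x = [1.0, 2.0, -3.0]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
edge_emb = [0.5, 0.5, 1.0, 1.0]

out = gin_conv_step(x, edge_index, edge_emb)  # [3.5, 3.5, 0.0]
```

Node 1 receives relu(1.0 + 0.5) = 1.5 from node 0 and relu(-3.0 + 1.0) = 0.0 from node 2, so its aggregate is 1.5 and its output is 2.0 + 1.5 = 3.5.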
Weisfeiler-Lehman Test (WL Test)
Graph isomorphism test:
- Iteratively aggregate the labels of each node and its neighbors.
- Hash the aggregated labels into new labels. In formula form:
$$L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(u)} L^{h-1}_{v}\right)$$
- The WL subtree kernel measures the similarity between graphs: it uses the node label counts from the different iterations as the representation vector of the graph.
- Detailed steps:
  - Aggregate the labels of the node itself and of its adjacent nodes into a label string.
  - Hash the labels, mapping the longer string to a shorter label.
  - Relabel the nodes.
- Graph similarity evaluation:
  - WL subtree kernel method: run the WL test to obtain multi-iteration labels for the nodes, count how many times each label occurs in the graph, and use the resulting count vector as the representation of the graph.
  - The inner product of the representation vectors of two graphs serves as the similarity estimate of the two graphs.
- Necessary condition for judging graph isomorphism: two nodes that have the same label and the same adjacent node labels are mapped to the same representation.
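The relabeling steps and the subtree kernel above can be sketched in a few lines of plain Python. This is an illustrative sketch, not a reference implementation: the "hash" is a shared lookup table that assigns a fresh integer to each new (own label, sorted neighbor labels) signature, and the names wl_label_counts, wl_subtree_kernel and table are hypothetical.

```python
from collections import Counter

def wl_label_counts(adj, labels, num_iters, table):
    """Run WL relabeling and count the labels seen in every iteration.

    adj: dict mapping node -> list of neighbors
    labels: dict mapping node -> initial label
    table: dict shared across graphs so that compressed labels are comparable
    """
    labels = dict(labels)
    counts = Counter((0, lab) for lab in labels.values())
    for it in range(1, num_iters + 1):
        new_labels = {}
        for u in adj:
            # aggregate the node's own label with the sorted neighbor labels
            signature = (labels[u], tuple(sorted(labels[v] for v in adj[u])))
            # "hash": map each distinct signature to a short integer label
            new_labels[u] = table.setdefault((it, signature), len(table))
        labels = new_labels  # relabel the nodes
        counts.update((it, lab) for lab in labels.values())
    return counts

def wl_subtree_kernel(counts_a, counts_b):
    """Inner product of the two label-count vectors."""
    return sum(counts_a[key] * counts_b[key] for key in counts_a)

# A triangle and a 3-node path, all nodes starting with the same label
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path = {0: [1], 1: [0, 2], 2: [1]}
init = {0: 1, 1: 1, 2: 1}

table = {}
tri_counts = wl_label_counts(triangle, init, 1, table)
path_counts = wl_label_counts(path, init, 1, table)

same = tri_counts == path_counts                      # False
k_tri_path = wl_subtree_kernel(tri_counts, path_counts)  # 12
k_tri_tri = wl_subtree_kernel(tri_counts, tri_counts)    # 18
```

After one iteration the label counts already differ, so the test reports the two graphs as non-isomorphic. Equal counts would only mean the test cannot tell the graphs apart, not that they are isomorphic.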
Epilogue
In this article, we studied the graph representation network based on the graph isomorphism network (GIN). To obtain a graph representation, we first compute node representations and then perform graph readout. GIN computes node representations in the same way the WL test updates node labels, so its discriminative power is upper-bounded by that of the WL test. In graph readout, we sum all node representations (a weighted sum if attention is used), which loses information about how the node representations are distributed.
References
- Paper that proposed Global Attention: "Gated Graph Sequence Neural Networks"
- Paper that proposed Set2Set: "Order Matters: Sequence to sequence for sets"
- All graph pooling methods integrated in PyG: Global Pooling Layers
- Weisfeiler-Lehman test: Brendan L. Douglas. The Weisfeiler-Lehman method and graph isomorphism testing. arXiv preprint arXiv:1101.5211, 2011.