1. MessagePassing base class
The key steps of a GNN are message passing, aggregation, and updating.
PyTorch Geometric provides the MessagePassing base class, whose MessagePassing.propagate() method already implements the computation for these three steps. To write the convolution layer of a custom graph algorithm, we only need to define a class that inherits from MessagePassing, choose the neighborhood aggregation scheme (aggr="add", aggr="mean" or aggr="max"), implement message() and, if needed, update() according to the specific graph algorithm, and call propagate() in the layer's forward() method. The general structure is as follows:
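```python
import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):  # class inheriting from the MessagePassing base class
    def __init__(self, in_channels, out_channels, **kwargs):
        kwargs.setdefault('aggr', 'add')  # neighborhood aggregation scheme
        super(MyConv, self).__init__(**kwargs)
        ...  # define the layer's parameters

    def forward(self, x, edge_index):
        ...  # optional preprocessing of x / edge_index
        # pass every tensor that message()/update() need as a keyword argument
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        ...  # compute and return the message of each edge
```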
2. MessagePassing source code
2.1 MessagePassing initialization
```python
def __init__(self, aggr: Optional[str] = "add",
             flow: str = "source_to_target", node_dim: int = -2,
             decomposed_layers: int = 1):
```
aggr: Neighborhood aggregation scheme; the default is "add", and it can also be "mean" or "max".
flow: Direction of message passing; the default is "source_to_target", and it can also be set to "target_to_source".
node_dim: The axis along which messages are propagated, i.e. the node dimension of the feature tensors; the default is -2 because -1 is the feature dimension.
2.2 MessagePassing.propagate
```python
MessagePassing.propagate(edge_index, size=None, **kwargs)
```
propagate() calls the message(), aggregate() and update() methods in turn. If edge_index is a SparseTensor and message_and_aggregate() has been implemented, propagate() calls message_and_aggregate() instead of message() and aggregate() (see 2.6).
edge_index: Either a Tensor or a SparseTensor. In Tensor form, edge_index has shape (2, N_edges): row 0 holds the source node of each edge and row 1 the target node. A SparseTensor stores the same edge information as a sparse adjacency matrix.
size: When size is None, the adjacency matrix is assumed to be square, of shape [N, N]. For a graph with two node types, such as a bipartite graph, the features and indices of the two node sets are independent of each other; by passing size=(N, M) and x=(x_N, x_M), propagate() can handle this case.
kwargs: Any additional data needed by the graph convolution (for example node features or per-edge coefficients) is passed in as keyword arguments.
2.3 MessagePassing.message()
This method computes the message sent from neighbor node j to center node i when flow="source_to_target". Any argument passed to propagate() can also be received by message(), and a node-level tensor passed to propagate() can be mapped to the target or source node of each edge by appending the suffix _i or _j to its name (for example, a tensor passed as x=x can be received as x_i or x_j).
```python
def message(self, x_j: Tensor) -> Tensor:
    return x_j
```
x_j: The features of the neighbor (source) node of each edge, obtained by indexing x with the neighbor indices stored in edge_index. When edge_index has shape (2, N_edges) and x has shape (N_nodes, N_features), x_j has shape (N_edges, N_features).
Example:
edge_index = tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
x = tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
The indices of the neighbor nodes j form the first row of edge_index, [1, 2, 3, 3]. Indexing x at these positions gives
x_j = x[edge_index[0]] = x[[1, 2, 3, 3]] = tensor([[2, 3], [4, 5], [6, 7], [6, 7]])
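The same lookup can be reproduced with plain PyTorch indexing. This is a sketch of the idea only, assuming PyG performs an equivalent gather internally; it also shows x_i, the features of the target node of each edge.

```python
import torch

edge_index = torch.tensor([[1, 2, 3, 3],
                           [0, 0, 0, 1]])
x = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7]])

# With flow="source_to_target": row 0 = neighbor nodes j, row 1 = center nodes i
j, i = edge_index[0], edge_index[1]
x_j = x.index_select(0, j)  # one row per edge: features of its source node
x_i = x.index_select(0, i)  # one row per edge: features of its target node

print(x_j.tolist())  # [[2, 3], [4, 5], [6, 7], [6, 7]]
print(x_i.tolist())  # [[0, 1], [0, 1], [0, 1], [2, 3]]
```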
2.4 MessagePassing.aggregate(inputs, index, …)
This method aggregates the messages arriving at each node from its neighborhood; PyTorch Geometric implements the add, mean and max schemes with scatter operations. For common graph algorithms such as GCN, GraphSAGE and GAT, there is generally no need to override aggregate().
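What the three scatter-based schemes compute can be sketched in plain PyTorch. This uses `torch.Tensor.scatter_reduce` as a stand-in for PyG's internal scatter implementation (an assumption for illustration, not PyG's own code), on the toy graph from 2.3: the messages x_j are reduced into the target node given by edge_index[1].

```python
import torch

edge_index = torch.tensor([[1, 2, 3, 3],
                           [0, 0, 0, 1]])
x = torch.tensor([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
num_nodes = x.size(0)

x_j = x[edge_index[0]]  # messages: source-node features, shape [N_edges, F]
target = edge_index[1]  # center node i receiving each message

# Broadcast the target index over the feature dimension
index = target.view(-1, 1).expand_as(x_j)

def scatter_aggregate(reduce):
    out = torch.zeros(num_nodes, x.size(1))
    # include_self=False: nodes with no incoming message keep the zero initialization
    return out.scatter_reduce(0, index, x_j, reduce=reduce, include_self=False)

print(scatter_aggregate("sum").tolist())   # aggr="add"
print(scatter_aggregate("mean").tolist())  # aggr="mean"
print(scatter_aggregate("amax").tolist())  # aggr="max"
```

Node 0 receives [2, 3], [4, 5] and [6, 7], so its sum is [12, 15], its mean [4, 5] and its max [6, 7]; nodes 2 and 3 receive nothing and stay at zero here.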
2.5 MessagePassing.update(aggr_out, …)
update() receives the output of aggregate() as its first argument, and any argument that was passed to propagate() can also be passed to it. For each center node i, it combines the aggregated neighborhood result with whatever other inputs from propagate() it needs, and computes the updated embedding of node i. By default, update() returns the aggregated result unchanged.
2.6 MessagePassing.message_and_aggregate(adj_t, …)
As mentioned earlier, edge information in PyTorch Geometric comes in two forms: Tensor and SparseTensor.
SparseTensor stores the edges as a sparse matrix, and message_and_aggregate() provides a matrix-based way to perform neighborhood aggregation (not every graph convolution can be expressed as a matrix computation).
When the edges are stored as a SparseTensor, propagate() first checks whether message_and_aggregate() has been implemented. If it has, propagate() calls message_and_aggregate() instead of message() and aggregate(). If it has not, propagate() converts the edge information to Tensor form and then calls message() and aggregate().
message_and_aggregate() has to be implemented by yourself, and only once it is implemented can the efficiency of matrix computation be exploited.
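Why a matrix formulation can replace message() plus aggregate() can be checked without PyG: for aggr="add" with plain x_j messages, gathering and scatter-adding over the edges gives the same result as multiplying the transposed adjacency matrix adj_t (rows are target nodes) by x. The sketch below uses `torch.sparse` as a stand-in for the SparseTensor class discussed above; that substitution is an assumption for illustration.

```python
import torch

edge_index = torch.tensor([[1, 2, 3, 3],
                           [0, 0, 0, 1]])
x = torch.tensor([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
N = x.size(0)
src, dst = edge_index

# Edge-wise path: gather x_j, then scatter-add into the target nodes
out_scatter = torch.zeros_like(x).index_add_(0, dst, x[src])

# Matrix path: adj_t[i, j] = 1 for every edge j -> i
adj_t = torch.sparse_coo_tensor(torch.stack([dst, src]),
                                torch.ones(src.numel()), (N, N))
out_matmul = torch.sparse.mm(adj_t, x)

print(out_scatter.tolist())                   # [[12.0, 15.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
print(torch.allclose(out_scatter, out_matmul))  # True
```

A single sparse matrix product replaces the per-edge gather and scatter, which is the advantage message_and_aggregate() is meant to exploit.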
3. Example
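```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.datasets import Planetoid

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        # "add" aggregation; messages flow from edge_index[0] (source) to edge_index[1] (target)
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [N, in_channels], edge_index: [2, E]
        # Step 1: add self-loops so every node also receives its own features
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: linearly transform the node features
        x = self.lin(x)
        # Step 3: symmetric normalization 1 / sqrt(deg(i) * deg(j))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)  # in-degree, >= 1 thanks to self-loops
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # one coefficient per edge
        # Step 4: message passing; norm reaches message() as an extra keyword argument
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: [E, out_channels]; scale each neighbor's features by its edge coefficient
        return norm.view(-1, 1) * x_j

dataset = Planetoid(root='dataset/Cora', name='Cora')  # downloads Cora on first use
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)  # torch.Size([2708, 64])
```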
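Written per node, the layer in this example computes the expression below. Here W and b are the weight and bias of self.lin, N(i) is the set of nodes j with an edge j → i in edge_index, and deg(·) is the in-degree computed by degree(col, ...) after add_self_loops, so it includes the self-loop.

```latex
\mathbf{x}_i' = \sum_{j \in \mathcal{N}(i) \cup \{i\}}
  \frac{1}{\sqrt{\deg(i)}\,\sqrt{\deg(j)}} \left( \mathbf{W}\mathbf{x}_j + \mathbf{b} \right)
```

The factor 1/sqrt(deg(j)) comes from deg_inv_sqrt[row] and 1/sqrt(deg(i)) from deg_inv_sqrt[col]; the sum over j is the "add" aggregation.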