PyG message passing (MessagePassing) source code + examples

1. MessagePassing base class

The key steps of a GNN are message passing, aggregation, and updating.
PyTorch Geometric (PyG) provides a MessagePassing base class, which implements the computation for these three steps through MessagePassing.propagate(). We only need to define a class that inherits from MessagePassing, choose the neighborhood aggregation method (aggr="add", aggr="mean" or aggr="max"), override message() and, where needed, update() according to the specific graph algorithm, and call propagate() in the forward function of the custom convolution layer. The general process is as follows:

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):  # a class that inherits from the MessagePassing base class
    def __init__(self, in_channels, out_channels, **kwargs):
        kwargs.setdefault('aggr', 'add')  # neighborhood aggregation method
        super().__init__(**kwargs)
        ...

    def forward(self, x, edge_index):
        ...
        # keyword arguments passed here (e.g. x=x) become available to message()
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        ...

2. MessagePassing source code

2.1 MessagePassing initialization

def __init__(self, aggr: Optional[str] = "add",
           flow: str = "source_to_target", node_dim: int = -2,
           decomposed_layers: int = 1):

aggr: Neighborhood aggregation method; the default is add, and mean and max are also supported.
flow: Direction of message passing; the default is source_to_target, and it can also be set to target_to_source.
node_dim: The dimension along which messages are propagated; the default is -2, because -1 is the feature dimension.
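
To make node_dim concrete, here is a minimal pure-PyTorch sketch (not PyG's actual code) of gathering neighbor features along dimension -2. The batched shape (batch, N_nodes, N_features) and all variable names are assumptions chosen only for illustration.

```python
import torch

# Assumed toy input: a batch of 2 feature matrices, 4 nodes each, 3 features per node.
x = torch.arange(24, dtype=torch.float).view(2, 4, 3)

# Source node of each of 5 hypothetical edges.
src = torch.tensor([1, 2, 3, 3, 0])

# With node_dim = -2 the node axis is the second-to-last one,
# so the gather leaves the batch axis and the feature axis untouched.
node_dim = -2
x_j = x.index_select(node_dim, src)

print(x_j.shape)  # torch.Size([2, 5, 3]): (batch, N_edges, N_features)
```

The last dimension is never indexed, which is why the feature axis can stay at -1 no matter how many leading batch dimensions there are.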

2.2 MessagePassing.propagate

MessagePassing.propagate(edge_index, size=None, **kwargs)

propagate calls the message, aggregate, and update methods in turn. If edge_index is a SparseTensor and message_and_aggregate is implemented, message_and_aggregate is called instead of message and aggregate.

edge_index: It has two forms, Tensor and SparseTensor. In Tensor form the shape of edge_index is (2, N_edges); a SparseTensor can be understood as storing the edge information in the form of a sparse matrix.
size: When size is None, the adjacency matrix is assumed to be square, [N, N]. In a heterogeneous graph (such as a bipartite graph), the features and indices of the two types of nodes are independent of each other. By passing in size=(N, M), x=(x_N, x_M), propagate can handle this situation.
kwargs: Additional information required for the graph convolution calculation process can be passed in through kwargs.
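
The flow described above can be sketched in plain PyTorch. This is a simplified stand-in, not PyG's implementation: propagate_sketch is a hypothetical helper, it only supports a Tensor edge_index with "add" aggregation, and message() and update() are taken to be the identity.

```python
import torch

def propagate_sketch(edge_index, x, size=None, flow="source_to_target"):
    """Hypothetical, simplified imitation of propagate() with aggr='add'."""
    # Row j of edge_index holds the sending nodes, row i the receiving nodes.
    j, i = (0, 1) if flow == "source_to_target" else (1, 0)
    # A single tensor means both node sets share the same features (square case).
    xs = x if isinstance(x, tuple) else (x, x)
    if size is None:
        size = (xs[0].size(0), xs[1].size(0))
    x_j = xs[j][edge_index[j]]                # gather + message(): identity
    out = torch.zeros(size[i], x_j.size(-1), dtype=x_j.dtype)
    out.index_add_(0, edge_index[i], x_j)     # aggregate(): sum per receiving node
    return out                                # update(): identity

# Square case: 4 nodes, edges 1->0, 2->0, 3->0, 3->1.
x = torch.tensor([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
edge_index = torch.tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
out_square = propagate_sketch(edge_index, x)

# Bipartite case: 3 source nodes send to 2 target nodes.
x_src = torch.tensor([[1.], [2.], [3.]])
x_dst = torch.zeros(2, 1)
bi_edge_index = torch.tensor([[0, 1, 2, 2], [0, 0, 1, 1]])
out_bipartite = propagate_sketch(bi_edge_index, (x_src, x_dst), size=(3, 2))

print(out_square.tolist())     # [[12.0, 15.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
print(out_bipartite.tolist())  # [[3.0], [6.0]]
```

In the bipartite call the output has one row per target node (size[1] = 2), even though the messages are read from the 3 source nodes.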

2.3 MessagePassing.message()

This method computes the message from neighbor node j to center node i under the setting flow="source_to_target". All arguments passed to propagate() can also be passed on to message(), and tensors passed to propagate() can be mapped to the corresponding nodes by appending the suffix _i or _j to the argument name.

def message(self, x_j: Tensor) -> Tensor:
    return x_j

x_j: The features of the neighbor nodes, obtained by indexing x with the neighbor-node indices in edge_index.

When the shape of edge_index is (2, N_edges) and the shape of x is (N_nodes, N_features), the resulting x_j has shape (N_edges, N_features).

Example:
edge_index: tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
x: tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
The indices of the neighbor nodes j are the first row of edge_index, [1, 2, 3, 3]. Indexing x with these indices gives

x_j = x[index(j)] = x[[1, 2, 3, 3]] = tensor([[2, 3], [4, 5], [6, 7], [6, 7]])
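
The same lookup can be reproduced directly in PyTorch. This is only a sketch of the indexing step; the names src and dst are chosen here for readability, and x_i is the analogous lookup for the center nodes (second row of edge_index under flow="source_to_target").

```python
import torch

edge_index = torch.tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
x = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7]])

src, dst = edge_index  # row 0: neighbor nodes j, row 1: center nodes i

x_j = x[src]  # one row per edge: features of the sending node
x_i = x[dst]  # one row per edge: features of the receiving node

print(x_j.tolist())  # [[2, 3], [4, 5], [6, 7], [6, 7]]
print(x_i.tolist())  # [[0, 1], [0, 1], [0, 1], [2, 3]]
```

Both results have shape (N_edges, N_features), so row k of x_j and row k of x_i describe the two endpoints of edge k.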

2.4 MessagePassing.aggregate(inputs, index, …)

This method implements neighborhood aggregation. PyTorch Geometric implements the three modes mean, add, and max through scatter operations. For common graph algorithms such as GCN, GraphSAGE, and GAT, there is generally no need to define a custom aggregate method.
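
What the three modes compute can be sketched with plain PyTorch scatter-style operations. This is an illustration under assumptions, not PyG's own code: it uses index_add_ and Tensor.scatter_reduce (the latter is only available in reasonably recent PyTorch versions), and the messages are the x_j values of the example in 2.3 converted to float.

```python
import torch

# One message per edge, and the receiving node of each edge.
messages = torch.tensor([[2., 3.], [4., 5.], [6., 7.], [6., 7.]])
index = torch.tensor([0, 0, 0, 1])
num_nodes = 4

# add: sum all messages that arrive at the same node.
out_add = torch.zeros(num_nodes, 2).index_add_(0, index, messages)

# mean: divide the sum by the number of incoming messages (at least 1).
count = torch.zeros(num_nodes).index_add_(0, index, torch.ones(index.numel()))
out_mean = out_add / count.clamp(min=1).unsqueeze(-1)

# max: element-wise maximum over incoming messages; nodes without messages stay 0.
expanded_index = index.unsqueeze(-1).expand_as(messages)
out_max = torch.zeros(num_nodes, 2).scatter_reduce(
    0, expanded_index, messages, reduce="amax", include_self=False
)

print(out_add.tolist())   # [[12.0, 15.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
print(out_mean.tolist())  # [[4.0, 5.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
print(out_max.tolist())   # [[6.0, 7.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
```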

2.5 MessagePassing.update(aggr_out, …)

The arguments passed to propagate are also available to update. For each center node i, update selects the information it needs from the aggregated neighborhood result and from the arguments passed to propagate, and uses it to update the embedding of node i.

2.6 MessagePassing.message_and_aggregate(adj_t, …)

As mentioned earlier, edge information in PyTorch Geometric has two forms: Tensor and SparseTensor.
SparseTensor stores the edges as a sparse matrix, and message_and_aggregate provides a matrix-based way to compute neighborhood aggregation (not all graph convolutions can be computed as a matrix operation).

When the edges are stored as a SparseTensor, propagate first checks whether message_and_aggregate has been implemented. If it has, propagate calls message_and_aggregate instead of message and aggregate. If it has not, propagate has to convert the edge information to a Tensor and then call message and aggregate.
message_and_aggregate has to be implemented by yourself; only then can the advantages of matrix computation be exploited.
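
The idea behind the matrix form can be shown without PyG: for "add" aggregation with an identity message, gathering and summing per edge gives the same result as multiplying a sparse adjacency matrix with the feature matrix. The sketch below uses torch.sparse_coo_tensor as a stand-in for SparseTensor, which is an assumption made only to keep the example self-contained; adj_t follows the naming of the heading above and has one row per receiving node.

```python
import torch

x = torch.tensor([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
edge_index = torch.tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
src, dst = edge_index
num_nodes = x.size(0)

# Edge-wise path: message() followed by aggregate().
out_edgewise = torch.zeros(num_nodes, x.size(1)).index_add_(0, dst, x[src])

# Matrix path: adj_t[i, j] = 1 if there is an edge j -> i.
adj_t = torch.sparse_coo_tensor(
    torch.stack([dst, src]), torch.ones(src.numel()), (num_nodes, num_nodes)
)
out_matrix = torch.sparse.mm(adj_t, x)

print(out_matrix.tolist())  # [[12.0, 15.0], [6.0, 7.0], [0.0, 0.0], [0.0, 0.0]]
print(torch.equal(out_edgewise, out_matrix))  # True
```

The matrix path never materializes the (N_edges, N_features) tensor x_j, which is where the savings of a fused implementation come from.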

3. Example

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops so that every node also receives its own features.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Linearly transform the node features.
        x = self.lin(x)
        # Per-edge symmetric normalization: 1 / sqrt(deg(row) * deg(col)).
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # x and norm are handed on to message().
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # Scale each neighbor feature by the normalization coefficient of its edge.
        return norm.view(-1, 1) * x_j

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)

Origin blog.csdn.net/weixin_45928096/article/details/126805227