Torch geometric NNConv 源码分析

公式

x i = Θ x i + j N ( i ) x j h Θ ( e i , j ) , \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),
其中, x i \mathbf{x}_i 是节点i的特征, Θ \mathbf{\Theta} 是待学习的参数矩阵, N ( i ) \mathcal{N}(i) 是节点i的所有邻接点, h Θ h_{\mathbf{\Theta}} 是一个神经网络,例如多层感知机, e i , j \mathbf{e}_{i,j} 是节点i和节点j的边特征。公式中的乘法均为矩阵乘法。这个公式的含义,就是对于节点i而言,新的节点特征由其原先的节点特征以及其邻节点的节点特征和与之相连的边特征得到。

NNConv源码

类及其参数:
class NNConv(in_channels, out_channels, nn, aggr='add', root_weight=True, bias=True, **kwargs)

  • in_channels (int) – Size of each input sample.特征的输入维度,一般是节点的隐藏状态维度
  • out_channels (int) – Size of each output sample.特征的输出维度,一般是节点的隐藏状态维度
  • nn (torch.nn.Module) – A neural network hΘ that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels], e.g., defined by torch.nn.Sequential.一个映射边特征的神经网络,形状从[-1, num_edge_features]映射到[-1, in_channels*out_channels
  • aggr (string, optional) – The aggregation scheme to use (“add”, “mean”, “max”). (default: “add”)聚合方法,默认是加法(就是公式中的累加符号,如果是mean就是对周围节点的信息取平均值)
  • root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output. (default: True)如果设置为False,在更新的时候不会把节点本身的特征加上,即上面的数学公式里第一项为0,一般默认True就好了。
  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)偏置,即在上面的数学公式最后加一个常数项
  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

如果你的用途是使用一层NNConv进行多次迭代的话(Neural Message Passing for Quantum Chemistry论文的思想就是这个),由于 Θ \mathbf{\Theta} 的尺寸是固定的,所以in_channelout_channel是相等的。如果不相等的话,假设in_channel=nout_channel=m,那在第一次迭代的时候 Θ \mathbf{\Theta} 的尺寸为(m, n),第一次迭代后新的节点特征维度是m,但是在第二次迭代的时候尺寸就会不匹配。

但是如果你的用途是构造多层NNConv的话,那么in_channelout_channel可以不相等。

另外要说明的是,参数nn是一个普通的全连接神经网络层,输出的维度是in_channels*out_channels,这也是为了和节点的特征维度匹配。一般做完这个神经网络层后会view(或者是reshape)一下,把维度变成(in_channels, out_channels),这样就可以与节点的特征做运算了。

了解完上面的概念以后,分析代码就比较简单了。

import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing

from ..inits import reset, uniform


class NNConv(MessagePassing):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nn,
                 aggr='add',
                 root_weight=True,
                 bias=True,
                 **kwargs):
        super(NNConv, self).__init__(aggr=aggr, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.nn)
        uniform(self.in_channels, self.root)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_attr):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
        return self.propagate(edge_index, x=x, pseudo=pseudo)

    def message(self, x_j, pseudo):
        weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

    def update(self, aggr_out, x):
        if self.root is not None:
            aggr_out = aggr_out + torch.mm(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

初始化

__init__()reset_parameters()就是保存一下参数,初始化参数,这个没啥好说的。主要是提一下nn这个参数,它是一个全连接的神经网络层。具体的前向传播,主要是下面3个函数构成。

前向传播

首先,我们可以看到NNConv类继承MessagePassing类,而MessagePassing又继承torch.nn.Module类。关于MessagePassing这个类,我们只需要知道它提供了一个方法self.propagate(),在这个方法里面会依次调用messageupdate函数,其他的以后再说。

forward

    def forward(self, x, edge_index, edge_attr):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
        return self.propagate(edge_index, x=x, pseudo=pseudo)

在把数据传入到NNConv的时候,会自动调用forward()函数。现在假设我们已经把数据传进来了,x是节点特征矩阵,形状为(num_nodes, node_features)edge_index是COO格式的邻接矩阵,其形状为(2, num_edges)(如果这里不懂这个邻接矩阵是怎么样的可以看后面的demo),edge_attr是边的特征矩阵,与edge_index一 一对应,其形状为(num_edges, edge_features)

这里forward主要是进行了预处理。首先是保证x是矩阵,因为有可能节点的维度只有1传进来是一个一维数组。类似地,也要保证edge_attr是矩阵。最后调用了propagate函数,但是正如我前面所说,我们只需要知道在propagate里面依次调用了messageupdate函数即可。

message

    def message(self, x_j, pseudo):
        weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

这里的输入参数x_j是邻接点的特征,pesudo是边的特征矩阵。我简化一下思路,假设x_j是某个节点i的邻接节点的特征向量,他们之间的边特征用pesudo向量表示。首先把pesudo向量放到全连接神经网络中,再view一下,最后把形状为(1, in_channels)(in_channels, out_channels)的两个矩阵相乘,就得到了形状为(1, out_channels)的新特征向量。
如果你对比公式,会发现这个函数就是在做累加符号里的事情。累加的过程在propagate里完成了。

update

    def update(self, aggr_out, x):
        if self.root is not None:
            aggr_out = aggr_out + torch.mm(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

最后的update阶段就更简单了,把公式中第一项和第二项相加就行了,如果有偏置的话就再加上偏置。之前提到累加的过程在propagate里完成了,这一结果被当做参数aggr_out传了进来。

demo

这里会举一个实际的例子并辅以手动计算来验证。demo使用pytorch geometric样例的图形,如下图所示
在这里插入图片描述

图片来自https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

首先定义一张图。为了更好地理解,图中我把节点的特征值改成2维的了(而不是上面一样每个节点只有1维特征值),另外对边(0,1)和(1,2)也定义了5维特征值,详见代码。

import torch
import torch.nn as nn
from torch_geometric.nn import NNConv


# 随机种子
torch.manual_seed(0)

# 定义维度
node_input_dim = 2
edge_input_dim = 5
edge_hidden_dim = 4 # 这个随意

# 定义边
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# 定义节点特征值
x = torch.tensor([[-1,2], [0,4], [1,5]], dtype=torch.float)

# 定义边的特征,考虑无向图,同一条边特征一样
edge_attr = torch.tensor([[1,2,3,5,5],
                          [1,2,3,4,5],
                          [5,4,3,2,1],
                          [5,4,3,2,1]], dtype=torch.float)


# 定义一个全连接神经网络
edge_network = nn.Sequential(
    nn.Linear(edge_input_dim, edge_hidden_dim),
    nn.ReLU(),
    nn.Linear(edge_hidden_dim, node_input_dim*node_input_dim)
)

conv = NNConv(node_input_dim, node_input_dim, edge_network)
x = conv(x, edge_index, edge_attr)
print(x)
# tensor([[-4.9210,  3.9953],
#         [-7.7736,  7.6713],
#         [-3.0483,  5.6156]], grad_fn=<AddBackward0>)

这里我强调一下edge_index这个变量的表示,edge_index[0]表示边的起点列表,edge_index[1]表示边的终点列表。因此edge_index[0][0]edge_index[1][0]共同表示一条边。

手动计算

message部分开始。x_j保存了邻接点的特征,具体而言,x_j[0]=[-1,2],这表示第一条边的起点特征是[-1,2]。因此,更为准确地说,x_j保存了所有边的终点的邻接点的特征。再换句话说,x_j[i]表示第i+1条边的起点特征。
把边的特征矩阵放到全连接神经网络后再view一下,输出结果为

[
[[ 1.4935, 1.3571],
[-1.4184, 0.6492]],
[[ 1.4935, 1.3571],
[-1.4184, 0.6492]],
[[ 1.4168, 1.0576],
[-1.2806, 0.9955]],
[[ 1.4168, 1.0576],
[-1.2806, 0.9955]]
]

形状为(4, 2, 2),这里有4个2×2的矩阵,对应四条无向边。将对应的x_j的分量与这4个矩阵相乘,得到一个形状为(4, 2)的矩阵,表示每条边和它的起点对它的终点的共同作用。

[[-4.3304, -0.0586],
[-5.6737, 2.5969],
[-5.1225, 3.9821],
[-4.9863, 6.0351]]

至此,已经计算出边及其起点的共同作用,据此我们可以计算每个节点得到的信息,具体的实现方式是根据边的终点进行汇总。例如,以0号节点作为终点的边只有一条,为(1,0),因此我从上面找到这条边的权重,即[-5.6737, 2.5969]。而以1号节点作为终点的边有两条,分别是(0,1)(2,1),因此我从上面找到这两条边的权重并相加(第1条边和第4条边),得到 [-9.3167, 5.9765],其他节点也按照此计算。最终这一步的输出为

[[-5.6737, 2.5969],
[-9.3167, 5.9765],
[-5.1225, 3.9821]]

至此已经完成了求和符号里的所有过程,公式中的第一项的计算比较容易,就不计算了(主要是我懒得把数值打上来)。

发布了19 篇原创文章 · 获赞 29 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/qq_41987033/article/details/103497749
今日推荐