The basic theory of GraphSAGE

GraphSAGE principle (for understanding)

Introduction:

Disadvantages of GCN:

  • Difficulty learning on large networks: GCN requires all nodes of the graph to be present while the embeddings are trained, which rules out mini-batch training of the model.
  • Difficulty extending to unseen nodes: GCN assumes a single fixed graph and learns an embedding for every vertex of that graph. In many practical applications, however, embeddings need to be generated quickly for nodes that were never seen. GCN cannot directly generalize to vertices that did not appear during training, which makes it a form of transductive learning.

GraphSAGE: the core idea is to generate the embedding vector of a target vertex by learning a function that aggregates its neighbor vertices.

GraphSAGE workflow

  1. Sample the neighbor vertices of each vertex in the graph. Instead of using the entire neighborhood of a given node, the model uniformly samples a fixed-size set of neighbors.
  2. Aggregate the information contained in the neighbor vertices according to the aggregation function.
  3. Get a vector representation of each vertex in the graph for use in downstream tasks.
  1. Sample the neighbor vertices of each vertex in the graph

    • For computational efficiency: a fixed number of neighbor vertices is sampled for each vertex as the vertices to be aggregated. Let the sample size be k. If a vertex has fewer than k neighbors, sample with replacement until k vertices are drawn; if it has more than k neighbors, sample without replacement.
    • If computational efficiency is not a concern: each vertex can simply aggregate over all of its neighbors, which loses no information.

    Specifically:

    At layer $k$, for each vertex $v$, first use the layer-$(k-1)$ embeddings of $v$'s neighbor vertices to generate the layer-$k$ aggregated neighborhood representation $h_{N(v)}^k$; then concatenate $h_{N(v)}^k$ with $v$'s own layer-$(k-1)$ representation $h_v^{k-1}$ and apply a nonlinear transformation to obtain $v$'s layer-$k$ embedding. In general the process can be written as (this only illustrates the flow; the exact form depends on the chosen aggregation function):

    $$h_{N(v)}^k = \mathrm{aggregate}_k\big(\{h_u^{k-1}, \forall u \in N(v)\}\big)$$

    $$h_v^k = \sigma\big(W^k \cdot \mathrm{concat}(h_v^{k-1}, h_{N(v)}^k)\big)$$

    (A minimal code sketch of this update appears after this list.)

  2. Aggregate the information contained in the neighbor vertices according to the aggregation function

    Selection of aggregate functions:

    • MEAN aggregator:
      $$h_v^k = \sigma\big(W \cdot \mathrm{MEAN}(\{h_v^{k-1}\} \cup \{h_u^{k-1}, \forall u \in N(v)\})\big)$$
    • Pooling aggregator:
      $$h_{N(v)}^k = \mathrm{pooling\_method}\big(\{\sigma(W h_u^k + b), \forall u \in N(v)\}\big)$$
      Note that this yields $h_{N(v)}^k$, not $h_v^k$; here $\mathrm{pooling\_method} \in \{\max, \mathrm{mean}\}$.
    • LSTM aggregator:
      Randomly shuffle the neighbor nodes of the central node and feed them to the LSTM as the input sequence; the resulting vector $h_{N(v)}^k$ is concatenated with the central node's vector $h_v^{k-1}$ and passed through a nonlinear transformation to obtain the central node's representation at this layer. LSTM is designed for sequential data, while neighbor nodes have no natural ordering, so the neighbors fed into the LSTM must be randomly shuffled.
  3. Parameter learning

    • Unsupervised learning

      Graph-based loss functions encourage adjacent vertices to have similar vector representations while pushing the representations of unrelated vertices apart. The objective function is:

      $$J_G(z_u) = -\log\big(\sigma(z_u^T z_v)\big) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log\big(\sigma(-z_u^T z_{v_n})\big)$$

      where $v$ is a node that co-occurs near $u$ on a fixed-length random walk, $\sigma$ is the sigmoid function, $P_n$ is a negative-sampling distribution, and $Q$ is the number of negative samples. (A minimal code sketch of this loss appears after this list.)

      Unlike DeepWalk, the vertex representation here is generated by aggregating the features of the vertex's neighbors, rather than by a simple embedding lookup.

    • Supervised learning

      In the supervised setting, the objective function can be set directly according to the task; for example, the most common node classification task uses the cross-entropy loss.
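To make the sampling and aggregation steps above concrete, here is a minimal PyTorch sketch of one GraphSAGE layer with the mean aggregator and fixed-size neighbor sampling. The helper names (sample_neighbors, sage_layer) and the toy adjacency list are illustrative only, not part of any library:

    import random
    import torch
    import torch.nn.functional as F

    def sample_neighbors(neighbors, k):
        # Fixed-size sampling: with replacement if there are fewer than k
        # neighbors, without replacement otherwise.
        if len(neighbors) < k:
            return [random.choice(neighbors) for _ in range(k)]
        return random.sample(neighbors, k)

    def sage_layer(h, adj_list, W, k):
        # h_v^k = relu(W @ concat(h_v^{k-1}, mean_{u in N(v)} h_u^{k-1}))
        new_h = []
        for v in range(h.size(0)):
            sampled = sample_neighbors(adj_list[v], k)
            h_neigh = h[sampled].mean(dim=0)           # aggregate step
            h_v = torch.cat([h[v], h_neigh], dim=-1)   # concat step
            new_h.append(F.relu(W @ h_v))              # nonlinear transform
        return torch.stack(new_h)

    # Toy usage: 4 nodes with 8-dim features and a hand-written adjacency list.
    h0 = torch.randn(4, 8)
    adj_list = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}
    W = torch.randn(16, 16)      # maps concat(8 + 8) -> 16
    h1 = sage_layer(h0, adj_list, W, k=2)
    print(h1.shape)              # torch.Size([4, 16])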
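Similarly, a minimal sketch of the unsupervised loss above; with Q drawn negatives, the $Q \cdot \mathbb{E}[\cdot]$ term is estimated by a sum over the Q samples (the tensor names z_u, z_v, z_neg are illustrative):

    import torch
    import torch.nn.functional as F

    def unsupervised_loss(z_u, z_v, z_neg):
        # z_u:   [d]    embedding of the anchor node u
        # z_v:   [d]    embedding of a node co-occurring with u on a random walk
        # z_neg: [Q, d] embeddings of Q negative samples drawn from P_n
        pos = F.logsigmoid((z_u * z_v).sum(-1))    # log sigma(z_u^T z_v)
        neg = F.logsigmoid(-(z_neg @ z_u)).sum()   # sum_n log sigma(-z_u^T z_n)
        return -pos - neg

    # Toy usage with random embeddings (d = 16, Q = 5 negatives).
    z_u, z_v = torch.randn(16), torch.randn(16)
    z_neg = torch.randn(5, 16)
    print(unsupervised_loss(z_u, z_v, z_neg))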

Practical basic theory of GraphSAGE (for coding)

1. The underlying implementation of GraphSAGE (PyTorch)

NeighborSampler in PyG implements node-level mini-batch neighbor sampling for GraphSAGE

Reference blog: https://blog.csdn.net/weixin_39925939/article/details/121458145
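As a rough sketch of how NeighborSampler is typically used (the exact import path depends on the PyG version: recent 2.x releases expose it under torch_geometric.loader and recommend NeighborLoader as its successor; the sizes, batch_size, and Cora dataset below are just example choices):

    from torch_geometric.datasets import Planetoid
    from torch_geometric.loader import NeighborSampler

    dataset = Planetoid(name='Cora', root='./dataset/Cora')
    data = dataset[0]

    # sizes=[25, 10]: sample at most 25 neighbors for the first hop and 10 for
    # the second hop of every node in the mini-batch (-1 would mean "all").
    train_loader = NeighborSampler(
        data.edge_index,
        node_idx=data.train_mask,
        sizes=[25, 10],
        batch_size=64,
        shuffle=True,
    )

    for batch_size, n_id, adjs in train_loader:
        # n_id: indices of every node involved in this mini-batch's computation graph.
        # adjs: one (edge_index, e_id, size) tuple per layer, ordered so they can be
        #       applied successively by the model's SAGEConv layers.
        print(batch_size, n_id.shape, [adj.size for adj in adjs])
        break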

SAGEConv implementation in PyG

The implementation in the original paper: $x_i' = W \cdot \mathrm{concat}\big(\mathrm{Aggregate}_{j \in N(i)} x_j,\ x_i\big)$

The implementation in PyG: $x_i' = W_1 x_i + W_2 \cdot \mathrm{Aggregate}_{j \in N(i)} x_j$

These two formulations are equivalent (splitting the $W$ of the concatenated form into $W_1$ and $W_2$); the point to note is:

The neighbors seen by SAGEConv are exactly the neighbors you pass in: whether you sample them with NeighborSampler, use some other sampling method, or pass the full unsampled neighborhood, the layer only consumes the edges it receives. Neighbor sampling itself is not implemented inside SAGEConv.
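As a quick illustration of this point, here is a tiny self-contained sketch (random features and a hand-written edge list, chosen only for this example) showing that SAGEConv aggregates exactly over the edges it is given:

    import torch
    from torch_geometric.nn import SAGEConv

    x = torch.randn(4, 16)                     # 4 nodes, 16-dim features
    edge_index = torch.tensor([[1, 2, 3],      # source (neighbor) nodes
                               [0, 0, 0]])     # target (central) nodes
    conv = SAGEConv(in_channels=16, out_channels=32, aggr='mean')
    out = conv(x, edge_index)                  # shape: [4, 32]

    # Node 0 aggregates nodes 1, 2 and 3 because those are the edges passed in;
    # leaving an edge out here is exactly what upstream "sampling" amounts to.
    print(out.shape)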

  • init function

    Parameter Description:

    • in_channels: Union[int, Tuple[int, int]]: dimension of the input raw features or hidden-layer embeddings. If set to -1, the feature dimension is inferred from the x that is passed in. Note that in_channels can be a single integer or a tuple of two integers, corresponding to the feature dimensions of the source nodes and target nodes:
    • source node: a neighbor node of the central node, $\{x_j, \forall j \in N(i)\}$
    • target node: the central node, $x_i$
    • in_channels[0]: shape[0] of the parameter $W_2$, which multiplies the aggregated source-node feature matrix
    • in_channels[1]: shape[0] of the parameter $W_1$, which multiplies the target-node feature matrix
    • out_channels: output embedding dimension
    • normalize: whether to apply $\ell_2$ normalization to the output; defaults to False
    • bias: bias term; defaults to True
    • root_weight: whether the output also adds the node's own features after a dimension transform; defaults to True
    • kwargs.setdefault('aggr', 'mean'): neighborhood aggregation method; defaults to aggr='mean', with aggr='max' and aggr='add' also available
    class SAGEConv(MessagePassing):
        def __init__(self, in_channels: Union[int, Tuple[int, int]],
                     out_channels: int, normalize: bool = False,
                     root_weight: bool = True,
                     bias: bool = True, **kwargs):  # yapf: disable
            kwargs.setdefault('aggr', 'mean')
            super(SAGEConv, self).__init__(**kwargs)

            self.in_channels = in_channels
            self.out_channels = out_channels
            self.normalize = normalize
            self.root_weight = root_weight

            if isinstance(in_channels, int):
                in_channels = (in_channels, in_channels)

            # lin_l transforms the aggregated neighbor (source) features;
            # lin_r transforms the central (target) node's own features.
            self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
            if self.root_weight:
                self.lin_r = Linear(in_channels[1], out_channels, bias=False)

            self.reset_parameters()
    
    
  • forward function

    Parameter Description:

    • x: Union[Tensor, OptPairTensor]: can be a Tensor or an OptPairTensor (a tuple of Tensors defined by PyG).

    When the graph is bipartite, x is an OptPairTensor; to match the in_channels definition in the init function:

    • the source-node (neighbor) features correspond to x[0], assigned to x_l in the code; in_channels[0] ($W_2$) defines lin_l
    • the target-node (central node) features correspond to x[1], assigned to x_r in the code; in_channels[1] ($W_1$) defines lin_r
    • edge_index: Adj: Adj is the adjacency type defined by PyG; it can be a Tensor or a SparseTensor.
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        # Aggregate the neighbor (source) features, then transform them with lin_l.
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        # Add the transformed features of the central (target) nodes (root weight).
        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out
    
    
  • Message passing function (message function)

    These are called by self.propagate inside the forward function rather than being invoked explicitly; which one runs depends on the type of the edge_index that is passed in.

    Parameter Description:

    • When edge_index is a Tensor:
      propagate calls message and aggregate to perform message passing and updating. Here the message function does not process the neighbor features; it simply passes them through, so propagate effectively just aggregates the neighbor features.
    • When edge_index is a SparseTensor:
      if message_and_aggregate is defined, propagate calls it instead of message and aggregate.
      Here message_and_aggregate directly performs a sparse matrix computation, matmul(adj_t, x[0], reduce=self.aggr), where x[0] is the source-node features. matmul comes from torch_sparse; in addition to ordinary matrix multiplication it accepts an optional reduce argument, so add, mean, and max aggregation can be realized.
    def message(self, x_j: Tensor) -> Tensor:
        # Pass the neighbor (source) features through unchanged;
        # the actual aggregation happens in the aggregate step.
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        # Fused message passing + aggregation via sparse matrix multiplication.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)
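    To see the two code paths side by side, here is a small sketch (using Cora only as a convenient example, and requiring the torch_sparse package for the SparseTensor form): the same SAGEConv layer is fed an edge_index Tensor and the adj_t SparseTensor produced by T.ToSparseTensor(); the two outputs should agree up to floating-point error.

    import torch
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import SAGEConv

    dataset = Planetoid(name='Cora', root='./dataset/Cora')
    data = dataset[0]                                  # keeps data.edge_index
    sparse_data = T.ToSparseTensor()(data.clone())     # converts edges to data.adj_t

    conv = SAGEConv(dataset.num_features, 64)

    with torch.no_grad():
        out_dense = conv(data.x, data.edge_index)            # message + aggregate path
        out_sparse = conv(sparse_data.x, sparse_data.adj_t)  # message_and_aggregate path

    # Should print True (up to floating-point tolerance).
    print(torch.allclose(out_dense, out_sparse, atol=1e-6))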
    

2. A GraphSAGE example

import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import SAGEConv

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.):
        super(SAGE, self).__init__()
        
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        
        self.dropout = dropout
        
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
            
    def forward(self, x, edge_index):
        x = self.convs[0](x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[1](x, edge_index)
        
        return x.log_softmax(dim=-1)

# Load the data
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

transform = T.ToSparseTensor()
# Because the ToSparseTensor() transform is applied here, the edge information is stored as adj_t; without this transform it would be edge_index
dataset = Planetoid(name='Cora', root=r'./dataset/Cora', transform=transform)
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()

model = SAGE(in_channels=dataset.num_features, hidden_channels=128, out_channels=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    model.train()
    
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)[data.train_mask] # As mentioned above, SAGEConv supports both the edge_index and adj_t forms
    loss = F.nll_loss(out, data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    
    out = model(data.x, data.adj_t)
    y_pred = out.argmax(axis=-1)
    
    correct = y_pred == data.y
    train_acc = correct[data.train_mask].sum().float()/data.train_mask.sum()
    valid_acc = correct[data.val_mask].sum().float()/data.val_mask.sum()
    test_acc = correct[data.test_mask].sum().float()/data.test_mask.sum()
    
    return train_acc, valid_acc, test_acc 

# Run 20 epochs and check how the model performs
for epoch in range(20):
    loss = train()
    train_acc, valid_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, '
          f'Loss: {loss:.4f}, '
          f'Train_acc: {100 * train_acc:.3f}%, '
          f'Valid_acc: {100 * valid_acc:.3f}%, '
          f'Test_acc: {100 * test_acc:.3f}%')

References

This article draws on:

  1. https://blog.csdn.net/weixin_39925939/article/details/121343538
  2. https://zhuanlan.zhihu.com/p/79637787
  3. https://zhuanlan.zhihu.com/p/336195862
