Notes on implementing a Graph Attention Network (GAT) with the DGL graph neural network framework

Reference list:

[1] In-depth understanding of the graph attention mechanism
[2] DGL official tutorial, part one: basic operations & message passing
[3] Cora dataset: introduction + reading with Python

1. DGL implements GAT to classify machine learning papers

The program below is excerpted from [1]. It uses the graph neural network framework DGL to implement a graph attention network (GAT), and the demo application classifies papers in the Cora machine learning paper dataset. (The figure below is taken from [3].)
[Figure: overview of the Cora citation dataset, from [3]]

1. Program

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # eq. 1: shared linear transformation applied to every node
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # eq. 2: attention function over concatenated source/destination features
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # eq. 2: unnormalized attention score for each edge
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # send the source feature z and the edge score e along each edge
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # eq. 3: softmax-normalize the scores over each node's incoming edges
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # eq. 4: attention-weighted sum of the neighbors' features
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        z = self.fc(h)                           # eq. 1
        self.g.ndata['z'] = z
        self.g.apply_edges(self.edge_attention)  # eq. 2
        self.g.update_all(self.message_func, self.reduce_func)  # eq. 3 and 4
        return self.g.ndata.pop('h')
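
For reference, the four equations the comments refer to (the single-head GAT update, written here in the notation of the GAT paper; the nonlinearity on the final $h_i'$ is applied outside this layer, by the F.elu in the GAT module below):

$z_i = W h_i$  (eq. 1)
$e_{ij} = \mathrm{LeakyReLU}\left(\vec{a}^{\,T} [z_i \,\|\, z_j]\right)$  (eq. 2)
$\alpha_{ij} = \exp(e_{ij}) \,/\, \sum_{k \in \mathcal{N}(i)} \exp(e_{ik})$  (eq. 3)
$h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} z_j$  (eq. 4)

where the sums run over the in-neighbors $j, k$ of node $i$.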
    
    
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concatenate the head outputs along the feature dimension
            return torch.cat(head_outs, dim=1)
        else:
            # average over the heads: (num_heads, N, out_dim) -> (N, out_dim)
            return torch.mean(torch.stack(head_outs), dim=0)
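
A minimal shape sketch (my addition, assuming the Cora graph g and features loaded later in this post): with in_dim=1433, out_dim=8 and num_heads=8, 'cat' concatenates the eight heads, while 'mean' would average them.

layer = MultiHeadGATLayer(g, in_dim=1433, out_dim=8, num_heads=8, merge='cat')
print(layer(features).shape)  # torch.Size([2708, 64]); merge='mean' would give torch.Size([2708, 8])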
            
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # layer1 concatenates num_heads outputs of size hidden_dim, hence the
        # input dimension below; the output layer uses a single head
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)  # nonlinearity between the two GAT layers
        h = self.layer2(h)
        return h

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)  # boolean mask over the training nodes
    g = DGLGraph(data.graph)
    return g, features, labels, mask

import time

g, features, labels, mask = load_cora_data()
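
As a quick sanity check, the loaded tensors should match the standard Cora statistics (2708 papers, 1433 bag-of-words features, 7 classes, 140 training nodes):

print(g.number_of_nodes(), features.shape, int(labels.max()) + 1, int(mask.sum()))
# 2708 torch.Size([2708, 1433]) 7 140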

net = GAT(g, in_dim=features.size(1), hidden_dim=8, out_dim=7, num_heads=8)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):
    if epoch >= 3:
        t0 = time.time()
        
    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
        
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))

2. Notes

2.1 Two ways to initialize a graph

For a graph with the edges shown below:
0 -> 1
1 -> 2
3 -> 1

Using parentheses (a tuple of a source-node tensor and a destination-node tensor):

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1])))  # parentheses: (src tensor, dst tensor)
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  # draw with networkx, setting node size and gray level
plt.show()

[Figure: the resulting graph drawn with networkx]
Or using square brackets (a list of edge pairs; newer DGL versions may accept only the tuple form above):

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])])  # square brackets: a list of (src, dst) pairs
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  # draw with networkx, setting node size and gray level
plt.show()

[Figure: the resulting graph drawn with networkx]
Note: for the same graph, the nodes are laid out at random positions each time it is drawn.
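
Since the layout is random, a more deterministic check is to compare edge lists directly; both construction styles should yield the same edge set:

src, dst = g.edges()
print(src.tolist(), dst.tolist())  # [0, 1, 3] [1, 2, 1]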

2.2 How DGL's update_all function actually works

This is illustrated with the following PageRank routine:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100      # number of nodes
DAMP = 0.85  # damping factor
K = 10       # number of iterations
g = nx.erdos_renyi_graph(N, 0.1)  # random graph generator, produces a networkx graph
g = dgl.DGLGraph(g)               # convert to a DGL graph
g.ndata['pv'] = torch.ones(N) / N # initialize the PageRank values
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  # store each node's in-degree as a feature
print(g.ndata['deg'])

# message function: divide each node's PageRank value by its degree and send
# the result to its neighbors (the nx graph is undirected, so after conversion
# every node's in-degree equals its out-degree)
def pagerank_message_func(edges):
    return {'pv': edges.src['pv'] / edges.src['deg']}

# reduce function: aggregate the messages from the mailbox and compute the
# node's new PageRank value
def pagerank_reduce_func(nodes):
    print("-batch size--pv size-------------")
    print(nodes.batch_size(), nodes.mailbox['pv'].size())
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv': pv}

g.update_all(pagerank_message_func, pagerank_reduce_func)
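
Note that the routine above calls update_all only once, so that the batch information printed below can be studied. A full PageRank computation would repeat the propagation K times, as in this sketch reusing the definitions above (each iteration triggers the same prints):

for _ in range(K):
    g.update_all(pagerank_message_func, pagerank_reduce_func)
print(g.ndata['pv'].sum())  # approximately 1.0, assuming the random graph has no isolated nodes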

Printing g.ndata['deg'] (that is, each node's in-degree) gives:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

The information printed inside pagerank_reduce_func is as follows:

-batch size--pv size-------------
1 torch.Size([1, 3])
-batch size--pv size-------------
2 torch.Size([2, 4])
-batch size--pv size-------------
5 torch.Size([5, 5])
-batch size--pv size-------------
6 torch.Size([6, 6])
-batch size--pv size-------------
10 torch.Size([10, 7])
-batch size--pv size-------------
7 torch.Size([7, 8])
-batch size--pv size-------------
12 torch.Size([12, 9])
-batch size--pv size-------------
16 torch.Size([16, 10])
-batch size--pv size-------------
11 torch.Size([11, 11])
-batch size--pv size-------------
11 torch.Size([11, 12])
-batch size--pv size-------------
8 torch.Size([8, 13])
-batch size--pv size-------------
2 torch.Size([2, 14])
-batch size--pv size-------------
7 torch.Size([7, 15])
-batch size--pv size-------------
1 torch.Size([1, 17])
-batch size--pv size-------------
1 torch.Size([1, 19])


Comparing the graph's in-degree information with the information printed inside pagerank_reduce_func, we find there is exactly one node with in-degree 3, two nodes with in-degree 4, five nodes with in-degree 5, and so on. From this we conclude:
1) update_all does not update all nodes in a single pass;
2) update_all groups nodes that receive the same number of messages (i.e. nodes with the same in-degree) into one batch and updates each batch together. This is why the nodes parameter of reduce_func(nodes) has type dgl.udf.NodeBatch: different rows of nodes hold the data of different nodes in the batch.
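
A tiny, self-contained example (my own construction, not from the original post) makes this degree bucketing visible. With edges 0->1, 1->2, 0->3 and 2->3, nodes 1 and 2 have in-degree 1 and node 3 has in-degree 2, so the reduce function fires once per in-degree bucket:

import dgl
import torch

g2 = dgl.graph((torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 3, 3])))
g2.ndata['pv'] = torch.ones(4)

def msg(edges):
    return {'pv': edges.src['pv']}

def red(nodes):
    print(nodes.batch_size(), nodes.mailbox['pv'].size())
    return {'pv': nodes.mailbox['pv'].sum(dim=1)}

g2.update_all(msg, red)
# 2 torch.Size([2, 1])   <- the two nodes with in-degree 1
# 1 torch.Size([1, 2])   <- the single node with in-degree 2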


Original post: blog.csdn.net/u013468614/article/details/115329460