Reference list:
[1] In-depth understanding of the graph attention mechanism
[2] DGL official tutorial, part one: basic operations & message passing
[3] Introduction to the Cora dataset + reading it in Python
1. Classifying machine-learning papers with a DGL implementation of GAT
The program is excerpted from [1]. It uses the graph neural network framework DGL to implement a graph attention network (GAT). The demo application classifies papers in the Cora machine-learning paper dataset into categories. (The picture below is taken from [3].)
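For reference, the four equations that the code comments below refer to as eq. 1 to eq. 4 are the standard GAT equations, as presented in [1]:

$$z_i^{(l)} = W^{(l)} h_i^{(l)} \quad \text{(eq. 1)}$$
$$e_{ij}^{(l)} = \mathrm{LeakyReLU}\left(\vec{a}^{(l)\,T} \left[z_i^{(l)} \,\|\, z_j^{(l)}\right]\right) \quad \text{(eq. 2)}$$
$$\alpha_{ij}^{(l)} = \frac{\exp\left(e_{ij}^{(l)}\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(e_{ik}^{(l)}\right)} \quad \text{(eq. 3)}$$
$$h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(l)} z_j^{(l)}\right) \quad \text{(eq. 4)}$$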
1. Code
Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        z = self.fc(h)  # eq. 1
        self.g.ndata['z'] = z
        self.g.apply_edges(self.edge_attention)  # eq. 2
        self.g.update_all(self.message_func, self.reduce_func)  # eq. 3 and 4
        return self.g.ndata.pop('h')
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concatenate the head outputs along the feature dimension
            return torch.cat(head_outs, dim=1)
        else:
            # average the head outputs (dim=0 averages across heads)
            return torch.mean(torch.stack(head_outs), dim=0)
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # layer1 concatenates its num_heads head outputs, so layer2's
        # input dimension is hidden_dim * num_heads
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # boolean masks are preferred over ByteTensor for indexing in PyTorch >= 1.2
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask
import time
import numpy as np

g, features, labels, mask = load_cora_data()

net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

dur = []
for epoch in range(300):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))
2. Notes
2.1 Two ways to initialize a graph
Consider a graph with the following directed edges:
0->1
1->2
3->1
Construction with parentheses, i.e. a (src, dst) tensor pair:
import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline

g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1])))  # parentheses: (src, dst) tensor pair
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  # draw with networkx; set node size and gray level
plt.show()
Or construction with square brackets, i.e. a list of edge pairs:
import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline

g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])])  # square brackets: list of edge pairs
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  # draw with networkx; set node size and gray level
plt.show()
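Both constructions should produce the same graph; a quick sanity check (assuming a DGL version that supports both forms, as above) is to print the edges:

src, dst = g.edges()
print(src.tolist(), dst.tolist())  # expected: [0, 1, 3] [1, 2, 1]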
Note: for the same graph, the layout of the nodes is randomized each time it is drawn.
2.2 How the DGL function update_all actually works
The following PageRank routine illustrates it:
import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100      # number of nodes
DAMP = 0.85  # damping factor
K = 10       # number of iterations

g = nx.erdos_renyi_graph(N, 0.1)  # random graph generator; produces a networkx graph
g = dgl.DGLGraph(g)               # convert to a DGL graph (each undirected edge becomes two directed edges)
g.ndata['pv'] = torch.ones(N) / N                 # initialize the PageRank values
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  # store each node's degree (in-degree equals out-degree here)
print(g.ndata['deg'])

# The message function divides each node's PageRank value by its degree and
# sends the result as a message to its neighbors:
def pagerank_message_func(edges):
    return {'pv': edges.src['pv'] / edges.src['deg']}

# The reduce function removes and aggregates the messages from the mailbox
# and computes the node's new PageRank value:
def pagerank_reduce_func(nodes):
    print("-batch size--pv size-------------")
    print(nodes.batch_size(), nodes.mailbox['pv'].size())
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv': pv}

g.update_all(pagerank_message_func, pagerank_reduce_func)
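Note that K = 10 is defined but the snippet runs only a single update; the full power iteration would simply repeat the call (a sketch, best run with the debug prints in pagerank_reduce_func removed to keep the output readable):

for k in range(K):
    g.update_all(pagerank_message_func, pagerank_reduce_func)
print(g.ndata['pv'])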
The printed g.ndata['deg'] (that is, each node's in-degree) looks like this:
tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])
The information printed inside pagerank_reduce_func is as follows:
-batch size--pv size-------------
1 torch.Size([1, 3])
-batch size--pv size-------------
2 torch.Size([2, 4])
-batch size--pv size-------------
5 torch.Size([5, 5])
-batch size--pv size-------------
6 torch.Size([6, 6])
-batch size--pv size-------------
10 torch.Size([10, 7])
-batch size--pv size-------------
7 torch.Size([7, 8])
-batch size--pv size-------------
12 torch.Size([12, 9])
-batch size--pv size-------------
16 torch.Size([16, 10])
-batch size--pv size-------------
11 torch.Size([11, 11])
-batch size--pv size-------------
11 torch.Size([11, 12])
-batch size--pv size-------------
8 torch.Size([8, 13])
-batch size--pv size-------------
2 torch.Size([2, 14])
-batch size--pv size-------------
7 torch.Size([7, 15])
-batch size--pv size-------------
1 torch.Size([1, 17])
-batch size--pv size-------------
1 torch.Size([1, 19])
Comparing the graph's in-degree information with what pagerank_reduce_func prints, we find exactly one node with an in-degree of 3, two nodes with an in-degree of 4, five nodes with an in-degree of 5, and so on. From this we conclude:
1) update_all does not update all nodes in a single batch;
2) update_all groups nodes that receive the same number of messages (i.e., have the same in-degree) into one batch and updates each batch together. This is why the input parameter of reduce_func(nodes) has type dgl.udf.NodeBatch: different rows of nodes (and of nodes.mailbox) correspond to different nodes in the same batch.
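A tiny, hypothetical example makes this degree bucketing visible; node 1 below receives two messages while nodes 2 and 3 receive one each, so the reduce function is called once per bucket:

import dgl
import torch

# edges: 0->1, 3->1, 1->2, 1->3 (node 1 has in-degree 2; nodes 2 and 3 have in-degree 1)
g2 = dgl.graph((torch.tensor([0, 3, 1, 1]), torch.tensor([1, 1, 2, 3])))
g2.ndata['x'] = torch.ones(4)

def msg(edges):
    return {'m': edges.src['x']}

def red(nodes):
    # one call per in-degree bucket:
    # expected output: "2 torch.Size([2, 1])" then "1 torch.Size([1, 2])"
    print(nodes.batch_size(), nodes.mailbox['m'].size())
    return {'x': nodes.mailbox['m'].sum(dim=1)}

g2.update_all(msg, red)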