参照リスト:
[1]グラフの注意メカニズムの詳細な理解
[2] DGL公式学習チュートリアル1-基本的な操作とメッセージパッシング
[3] Coraデータセットの紹介+ Pythonの読み取り
1.DGLはGAT分類機械学習ペーパーを実装します
このプログラムは、グラフニューラルネットワークフレームワーク-DGLを使用してグラフアテンションネットワーク(GAT)を実現する[1]から抜粋したものです。アプリケーションのデモは、機械学習の紙のデータセットであるCoraの紙のカテゴリを分類することです。(下の写真は[3]から取られたものです)
1.手順
Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim):
super(GATLayer, self).__init__()
self.g = g
self.fc = nn.Linear(in_dim, out_dim, bias=False)
self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
def edge_attention(self, edges):
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
a = self.attn_fc(z2)
return {
'e' : F.leaky_relu(a)}
def message_func(self, edges):
return {
'z' : edges.src['z'], 'e' : edges.data['e']}
def reduce_func(self, nodes):
alpha = F.softmax(nodes.mailbox['e'], dim=1)
h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
return {
'h' : h}
def forward(self, h):
z = self.fc(h) # eq. 1
self.g.ndata['z'] = z
self.g.apply_edges(self.edge_attention) # eq. 2
self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4
return self.g.ndata.pop('h')
class MultiHeadGATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
super(MultiHeadGATLayer, self).__init__()
self.heads = nn.ModuleList()
for i in range(num_heads):
self.heads.append(GATLayer(g, in_dim, out_dim))
self.merge = merge
def forward(self, h):
head_outs = [attn_head(h) for attn_head in self.heads]
if self.merge == 'cat':
return torch.cat(head_outs, dim=1)
else:
return torch.mean(torch.stack(head_outs))
class GAT(nn.Module):
def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
super(GAT, self).__init__()
self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
def forward(self, h):
h = self.layer1(h)
h = F.elu(h)
h = self.layer2(h)
return h
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
def load_core_data():
data = citegrh.load_cora()
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
g = DGLGraph(data.graph)
return g, features, labels, mask
import time
import numpy as np
g, features, labels, mask = load_core_data()
net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):
if epoch >= 3:
t0 = time.time()
logits = net(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.注意事項
2.1グラフを初期化する2つの方法
以下に示すデータ構造の場合:
0-> 1
1-> 2
3-> 1
複数の括弧の方法
import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]]) #使用nx绘制,设置节点大小及灰度值
plt.show()
または角括弧:
import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]]) #使用nx绘制,设置节点大小及灰度值
plt.show()
注:同じグラフでは、毎回印刷される各ノードの位置はランダムです。
2.2DGLupdate_all
機能の実際の作業プロセス
次のルーチンの説明を使用します。
import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl
N = 100 # number of nodes
DAMP = 0.85 # damping factor阻尼因子
K = 10 # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g) #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float() #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):
return {
'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):
print("-batch size--pv size-------------")
print(nodes.batch_size(), nodes.mailbox['pv'].size())
msgs = torch.sum(nodes.mailbox['pv'], dim=1)
pv = (1 - DAMP) / N + DAMP * msgs
return {
'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)
g.ndata ['deg']情報(つまり、各ノードの次数情報)を次のように出力します。
tensor([9.、7.、17.、10.、12.、13.、13.、9.、5.、14.、7.、12.、15.、6.、
15.、7。 、13.、7.、11.、9.、9.、15.、9.、12.、10.、8.、10.、9.、
15.、7.、8.、10.、10 。、8.、11.、13.、6.、10.、10.、11.、5.、13.、
6.、12.、12.、8.、6.、11.、9。、 10.、12.、8.、11.、5.、7.、12.、
4.、7.、8.、13.、11.、14.、9.、10.、12.、10。 、10.、9.、10.、13.、
7.、15.、15.、10.、6.、11.、4.、6.、5.、10.、9.、11.、19 。、9.、
12.、13.、15.、12.、12.、11.、10.、8.、11.、9.、7.、7.、11.、3.、
10。、 5.])
pagerank_reduce_func
関数内の印刷情報は次のとおりです。
-バッチサイズ–pvサイズ-------------
1 torch.Size([1、3])-
バッチサイズ–pvサイズ-------------
2 torch.Size([2、4])-
batch size–pv size -------------
5 torch.Size([
5、5 ])- batch size–pv size ---- ---------
6 torch.Size([6、6])-
batch size–pv size -------------
10 torch.Size([10、7])
-バッチサイズ–pvサイズ-------------
7 torch.Size([7、8])-
バッチサイズ–pvサイズ-------------
12 torch.Size([12、9])-
batch size–pv size -------------
16 torch.Size([
16、10 ])- batch size–pv size ---- ---------
11 torch.Size([11、11])-
batch size–pv size -------------
11 torch.Size([11、12])
-バッチサイズ–pvサイズ-------------
8 torch.Size([8、13])-
batch size–pv size -------------
2 torch.Size([
2、14 ])- batch size–pv size --- ----------
7 torch.Size([7、15])-
batch size–pv size -------------
1 torch.Size([1、17] )
-batch size–pv size -------------
1 torch.Size([1、19])
インディグリーが3のノードは1つだけで、インディグリーが4のノードは2つ、インディグリーが5のノードは5つです。
pagerank_reduce_func
グラフの度数情報を関数内の印刷された情報と比較すると、度数3のノードが1つだけ、度数4のノードが2つ、および度数が5つのノードがあることがわかります。次数が5であるため、次のようになります
。1)関数update_all
は、すべてのノードをまとめて更新するわけではありません
。2)関数update_all
は、同じ数のターゲットノードを持つノードをまとめて更新してバッチを形成しreduce_func(nodes)
ます。 dgl.udf.NodeBatchの理由です。reduce_func(nodes)
入力パラメータノードのさまざまな行は、さまざまなノードに関連するデータを表します。