グラフニューラルネットワークフレームワークDGLはグラフアテンションネットワーク(GAT)ノートを実装します

参照リスト:

[1]グラフの注意メカニズムの詳細な理解
[2] DGL公式学習チュートリアル1-基本的な操作とメッセージパッシング
[3] Coraデータセットの紹介+ Pythonの読み取り

1.DGLはGAT分類機械学習ペーパーを実装します

このプログラムは、グラフニューラルネットワークフレームワーク-DGLを使用してグラフアテンションネットワーク(GAT)を実現する[1]から抜粋したものです。アプリケーションのデモは、機械学習の紙のデータセットであるCoraの紙のカテゴリを分類することです。(下の写真は[3]から取られたものです)
ここに画像の説明を挿入

1.手順

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {
    
    'e' : F.leaky_relu(a)}
    
    def message_func(self, edges):
        return {
    
    'z' : edges.src['z'], 'e' : edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {
    
    'h' : h}
    def forward(self, h):
        z = self.fc(h) # eq. 1
        self.g.ndata['z'] = z 
        self.g.apply_edges(self.edge_attention) # eq. 2
        self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4
        return self.g.ndata.pop('h')
    
    
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
        
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            return torch.cat(head_outs, dim=1)
        else:
            return torch.mean(torch.stack(head_outs))
            
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_core_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

import time 
import numpy as np
g, features, labels, mask = load_core_data()

net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):
    if epoch >= 3:
        t0 = time.time()
        
    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
        
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))

2.注意事項

2.1グラフを初期化する2つの方法

以下に示すデータ構造の場合:
0-> 1
1-> 2
3-> 1

複数の括弧の方法

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

ここに画像の説明を挿入
または角括弧:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

ここに画像の説明を挿入
注:同じグラフでは、毎回印刷される各ノードの位置はランダムです。

2.2DGLupdate_all機能実際の作業プロセス

次のルーチンの説明を使用します。

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl
 
N = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):
    return {
    
    'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):
    print("-batch size--pv size-------------")
    print(nodes.batch_size(), nodes.mailbox['pv'].size())
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {
    
    'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

g.ndata ['deg']情報(つまり、各ノードの次数情報)を次のように出力します。

tensor([9.、7.、17.、10.、12.、13.、13.、9.、5.、14.、7.、12.、15.、6.、
15.、7。 、13.、7.、11.、9.、9.、15.、9.、12.、10.、8.、10.、9.、
15.、7.、8.、10.、10 。、8.、11.、13.、6.、10.、10.、11.、5.、13.、
6.、12.、12.、8.、6.、11.、9。、 10.、12.、8.、11.、5.、7.、12.、
4.、7.、8.、13.、11.、14.、9.、10.、12.、10。 、10.、9.、10.、13.、
7.、15.、15.、10.、6.、11.、4.、6.、5.、10.、9.、11.、19 。、9.、
12.、13.、15.、12.、12.、11.、10.、8.、11.、9.、7.、7.、11.、3.、
10。、 5.])

pagerank_reduce_func関数内の印刷情報は次のとおりです。

-バッチサイズ–pvサイズ-------------
1 torch.Size([1、3])-
バッチサイズ–pvサイズ-------------
2 torch.Size([2、4])-
batch size–pv size -------------
5 torch.Size([
5、5 ])- batch size–pv size ---- ---------
6 torch.Size([6、6])-
batch size–pv size -------------
10 torch.Size([10、7])
-バッチサイズ–pvサイズ-------------
7 torch.Size([7、8])-
バッチサイズ–pvサイズ-------------
12 torch.Size([12、9])-
batch size–pv size -------------
16 torch.Size([
16、10 ])- batch size–pv size ---- ---------
11 torch.Size([11、11])-
batch size–pv size -------------
11 torch.Size([11、12])
-バッチサイズ–pvサイズ-------------
8 torch.Size([8、13])-
batch size–pv size -------------
2 torch.Size([
2、14 ])- batch size–pv size --- ----------
7 torch.Size([7、15])-
batch size–pv size -------------
1 torch.Size([1、17] )
-batch size–pv size -------------
1 torch.Size([1、19])

インディグリーが3のノードは1つだけで、インディグリーが4のノードは2つ、インディグリーが5のノードは5つです。

pagerank_reduce_funcグラフの度数情報を関数内の印刷された情報と比較すると、度数3のノードが1つだけ、度数4のノードが2つ、および度数が5つのノードがあることがわかります。次数が5であるため、次のようになります
。1)関数update_allは、すべてのノードをまとめて更新するわけではありません
。2)関数update_allは、同じ数のターゲットノードを持つノードをまとめて更新してバッチを形成しreduce_func(nodes)ます。 dgl.udf.NodeBatchの理由です。reduce_func(nodes)入力パラメータノードのさまざまな行は、さまざまなノードに関連するデータを表します。

おすすめ

転載: blog.csdn.net/u013468614/article/details/115329460