Graph marco de red neuronal DGL implementa notas de Graph Attention Network (GAT)

Lista de referencia:

[1] Comprensión en profundidad del mecanismo de atención de gráficos
[2] Tutorial de aprendizaje oficial de DGL: operaciones básicas y transmisión de mensajes
[3] Introducción al conjunto de datos de Cora + lectura de Python

1. DGL implementa documentos de aprendizaje automático de clasificación GAT

El programa está extraído de [1], que realiza el uso de un marco de red neuronal de gráficos - DGL para realizar una red de atención de gráficos (GAT). La demostración de la aplicación sirve para clasificar la categoría de los artículos en el conjunto de datos de papel de aprendizaje automático: Cora. (La siguiente imagen está tomada de [3])
Inserte la descripción de la imagen aquí

1. Procedimiento

Ubuntu: 18.04
cuda: 11.1
cudnn: 8.0.4.30
pytorch: 1.7.0
networkx: 2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {
    
    'e' : F.leaky_relu(a)}
    
    def message_func(self, edges):
        return {
    
    'z' : edges.src['z'], 'e' : edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {
    
    'h' : h}
    def forward(self, h):
        z = self.fc(h) # eq. 1
        self.g.ndata['z'] = z 
        self.g.apply_edges(self.edge_attention) # eq. 2
        self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4
        return self.g.ndata.pop('h')
    
    
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
        
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            return torch.cat(head_outs, dim=1)
        else:
            return torch.mean(torch.stack(head_outs))
            
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_core_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

import time 
import numpy as np
g, features, labels, mask = load_core_data()

net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):
    if epoch >= 3:
        t0 = time.time()
        
    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
        
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))

2. Notas

2.1 Dos formas de inicializar un gráfico

Para la estructura de datos como se muestra a continuación:
0-> 1
1-> 2
3-> 1

Multi-llamada paréntesis camino

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

Inserte la descripción de la imagen aquí
O corchetes:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

Inserte la descripción de la imagen aquí
Nota: En el mismo gráfico, la posición de cada nodo que se imprime cada vez es aleatoria.

2.2 El update_allproceso de trabajo real de las funciones DGL

Utilice la siguiente descripción de rutina:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl
 
N = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):
    return {
    
    'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):
    print("-batch size--pv size-------------")
    print(nodes.batch_size(), nodes.mailbox['pv'].size())
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {
    
    'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

Imprima la información de g.ndata ['deg'] (es decir, la información en grados de cada nodo) de la siguiente manera:

tensor ([9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7. , 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10 ., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10. , 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19 ., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_funcLa información de impresión en la función es la siguiente:

-Tamaño de lote – Tamaño de PV -------------
1 Antorcha. Tamaño ([1, 3])
-Tamaño de lote – Tamaño de PV -------------
2 antorcha.Tamaño ([2, 4])
-tamaño del lote – tamaño pv -------------
5 antorcha.Tamaño ([5, 5])
-tamaño del lote – tamaño pv ---- ---------
6 antorcha.Tamaño ([6, 6])
-tamaño de lote – tamaño pv -------------
10 antorcha.Tamaño ([10, 7])
-Tamaño del lote – Tamaño del PV -------------
7 Tamaño de la antorcha ([7, 8])
-Tamaño del lote – Tamaño del PV -------------
12 antorcha.Tamaño ([12, 9])
-tamaño del lote – tamaño pv -------------
16 antorcha.Tamaño ([16, 10])
-tamaño del lote – tamaño pv ---- ---------
11 antorcha.Tamaño ([11, 11])
-tamaño de lote – tamaño pv -------------
11 antorcha.Tamaño ([11, 12])
-tamaño del lote – tamaño pv -------------
8 antorcha.Tamaño ([8, 13])
-tamaño del lote – tamaño pv -------------
2 antorcha.Tamaño ([2, 14])
-tamaño del lote – tamaño pv --- ----------
7 antorcha.Tamaño ([7, 15])
-tamaño de lote – tamaño pv -------------
1 antorcha.Tamaño ([1, 17] )
-tamaño de lote – tamaño pv -------------
1 antorcha.Tamaño ([1, 19])

Solo hay un nodo con un grado interno de 3, dos nodos con un grado interno de 4 y cinco nodos con un grado interno de 5, ...

Comparando la información en grados del pagerank_reduce_funcgráfico con la información impresa en la función, encontramos que solo hay un nodo con un grado interno de 3, dos nodos con un grado interno de 4 y cinco nodos con un grado interno grado de 5, por lo que obtenemos:
1) La función update_allno actualiza todos los nodos juntos;
2) La función actualiza los nodos update_allcon el mismo número de nodos de destino juntos para formar un lote, por lo que el reduce_func(nodes)tipo de parámetro de entrada en el parámetro de entrada es el motivo de dgl.udf.NodeBatch. reduce_func(nodes)Las diferentes filas de los nodos de los parámetros de entrada representan datos relacionados con diferentes nodos.

Supongo que te gusta

Origin blog.csdn.net/u013468614/article/details/115329460
Recomendado
Clasificación