深度学习新星:图卷积神经网络GCN

深度学习新星:图卷积神经网络GCN

作者:金松

引言

深度学习一直都是被几大经典模型给统治着,如CNN、RNN等等,它们无论再CV还是NLP领域都取得了优异的效果,那这个GCN是怎么跑出来的?是因为我们发现了很多CNN、RNN无法解决或者效果不好的问题——图结构的数据。

图片或者语言,都属于欧式空间的数据,因此才有维度的概念,欧式空间的数据的特点就是结构很规则。但是现实生活中,其实有很多很多不规则的数据结构,典型的就是图结构,或称拓扑结构,如社交网络、化学分子结构、知识图谱等等;即使是语言,实际上其内部也是复杂的树形结构,也是一种图结构;而像图片,在做目标识别的时候,我们关注的实际上只是二维图片上的部分关键点,这些点组成的也是一个图的结构。

图的结构一般来说是十分不规则的,可以认为是无限维的一种数据,所以它没有平移不变性。每一个节点的周围结构可能都是独一无二的,这种结构的数据,就让传统的CNN、RNN瞬间失效。所以很多学者从上个世纪就开始研究怎么处理这类数据了。这里涌现出了很多方法,例如GNN、DeepWalk、node2vec等等,GCN只是其中一种,这里只讲GCN,其他的后面有空再讨论。

GCN,图卷积神经网络,实际上跟CNN的作用一样,就是一个特征提取器,只不过它的对象是图数据。GCN精妙地设计了一种从图数据中提取特征的方法,从而让我们可以使用这些特征去对图数据进行节点分类(node classification)、图分类(graph classification)、边预测(link prediction),还可以顺便得到图的嵌入表示(graph embedding),可见用途广泛。因此现在人们脑洞大开,让GCN到各个领域中发光发热。
本文会用最简单的GCN在拳击俱乐部社交网络上做分类任务,让没接触过的童鞋较快理解。

0. 问题描述

首先,简单介绍一下数据集。

Zachary’s Karate Club是一个描述大学空手道俱乐部成员社交关系的网络,由Wayne W. Zachary在论文《An Information Flow Model for Conflict and Fission in Small Groups》中提出,是一个常用的社交网络示例。这个空手道俱乐部包含34名成员,管理员 John A 和教官 Mr. Hi 之间的一次冲突导致这个俱乐部一分为二,一半的成员围绕着 Mr. Hi 成立了一个新俱乐部,另一半成员要么找到了新的教练,要么放弃了空手道。因此,在对应的社交网络中,节点也被划分为两个组,一组属于Mr. Hi (Instructor) ,另一组属于John A (Administrator),其中节点0代表Mr. Hi,节点33代表John A。

我们可以利用networkx直接获取Zachary’s Karate Club数据,此时管理员 John A 被称为Officer。任务是预测每个节点会加入哪一边(0or33)。对该社交网络的可视化如下:

1. 创建一张graph

首先创建关于拳击俱乐部的网络

import dgl
import numpy as np

def build_karate_club_graph():
    # All 78 edges are stored in two numpy arrays. One for source endpoints
    # while the other for destination endpoints.
    src = np.array([1, 2, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 9, 10, 10,
        10, 11, 12, 12, 13, 13, 13, 13, 16, 16, 17, 17, 19, 19, 21, 21,
        25, 25, 27, 27, 27, 28, 29, 29, 30, 30, 31, 31, 31, 31, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33,
        33, 33, 33, 33, 33, 33, 33, 33, 33, 33])
    dst = np.array([0, 0, 1, 0, 1, 2, 0, 0, 0, 4, 5, 0, 1, 2, 3, 0, 2, 2, 0, 4,
        5, 0, 0, 3, 0, 1, 2, 3, 5, 6, 0, 1, 0, 1, 0, 1, 23, 24, 2, 23,
        24, 2, 23, 26, 1, 8, 0, 24, 25, 28, 2, 8, 14, 15, 18, 20, 22, 23,
        29, 30, 31, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30,
        31, 32])
    # Edges are directional in DGL; Make them bi-directional.
    u = np.concatenate([src, dst])
    v = np.concatenate([dst, src])
    # Construct a DGLGraph
    return dgl.DGLGraph((u, v))

打印出新定义 Graph 的节点和边

G = build_karate_club_graph()
print('We have %d nodes.' % G.number_of_nodes())
print('We have %d edges.' % G.number_of_edges())

用 networkx 可视化新的graph

import networkx as nx
# Since the actual graph is undirected, we convert it for visualization
# purpose.
nx_G = G.to_networkx().to_undirected()
# Kamada-Kawaii layout usually looks pretty for arbitrary graphs
pos = nx.kamada_kawai_layout(nx_G)
nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])

2. 给边和节点赋予特征

图神经网络会联合节点和边的特征做训练。

在这个例子中,因为没有节点的特征,就用one-hot的embedding方法得到维度为5的特征

import torch
import torch.nn as nn
import torch.nn.functional as F

embed = nn.Embedding(34, 5)  # 34 nodes with embedding dim equal to 5
G.ndata['feat'] = embed.weight

打印出节点的特征来验证下

# print out node 2's input feature
print(G.ndata['feat'][2])

# print out node 10 and 11's input features
print(G.ndata['feat'][[10, 11]])

3. 定义一个图卷积网络

简单的定义一个图卷积神经网络框架。

  • 在第 $l $ 层,每个节点 v i l v_i^l 用一个节点向量 h i l h_i^l 表示;
  • GCN的每一层的目的是聚合每一个节点 v i l v_i^{l} 的邻居节点们 u i u_i 用来生成下一层的向量表示 v i l + 1 v_i^{l+1} ,然后接一个非线性的激活函数。

上面整个步骤可以看作一个message-passing的范式:每个节点会接受邻居节点的信息从而更新自身的节点表示。一个图形化的例子就是:

DGL库提供了 GCN 层的实现

from dgl.nn.pytorch import GraphConv

定义了包含了两个GCN层的GCN模型

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_size)
        self.conv2 = GraphConv(hidden_size, num_classes)

    def forward(self, g, inputs):
        h = self.conv1(g, inputs)
        h = torch.relu(h)
        h = self.conv2(g, h)
        return h

# The first layer transforms input features of size of 5 to a hidden size of 5.
# The second layer transforms the hidden layer and produces output features of
# size 2, corresponding to the two groups of the karate club.
net = GCN(5, 5, 2)

4. 准备数据 & 初始化

使用one-hot向量初始化节点。因为是一个半监督的设定,仅有指导员(节点0)和俱乐部主席(节点33)被分配了label,实现如下:

inputs = embed.weight
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1])  # their labels are different

5. 训练 & 可视化展示

训练的步骤和PyTorch模型一样

  • 创建优化器,
  • 输入input数据,
  • 计算loss,
  • 使用反向传播优化模型
import itertools

optimizer = torch.optim.Adam(itertools.chain(net.parameters(), embed.parameters()), lr=0.01)
all_logits = []
for epoch in range(50):
    logits = net(G, inputs)
    # we save the logits for visualization later
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for labeled nodes
    loss = F.nll_loss(logp[labeled_nodes], labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))

这是一个非常简单的小例子,甚至没有划分验证集和测试集。因此,因为模型最后输出了每个节点的二维向量,我们可以轻易的在2D的空间将这个过程可视化出来,下面的代码动态的展示了训练过程中从开始的状态到到最后所有节点都线性可分的过程。

import matplotlib.animation as animation
import matplotlib.pyplot as plt

def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(nx_G.to_undirected(), pos, node_color=colors,
            with_labels=True, node_size=300, ax=ax)

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
draw(0)  # draw the prediction of the first epoch
plt.close()

下面的动态过程展示了模型经过一段训练之后能够准确预测节点属于哪个群组。

ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)

项目实战链接:https://momodel.cn/workspace/5e8b3a29142d1d72944d121f/app

参考文献:

##关于我们
Mo(网址:https://momodel.cn) 是一个支持 Python的人工智能在线建模平台,能帮助你快速开发、训练并部署模型。

近期 Mo 也在持续进行机器学习相关的入门课程和论文分享活动,欢迎大家关注我们的公众号获取最新资讯!

发布了36 篇原创文章 · 获赞 4 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weixin_44015907/article/details/105362963
今日推荐