PyG Tutorial (4): Node2Vec Node Classification and Visualization

This post shows how to quickly implement Node2Vec node classification with PyTorch Geometric and visualize the results.

The whole process consists of four steps:

  • Load the graph data (Cora is used here)
  • Create the Node2Vec model
  • Train and test the model
  • Visualize the embeddings after t-SNE dimensionality reduction

The parameters of Node2Vec are as follows:

  • edge_index (LongTensor): the edge indices of the graph (a COO edge list, not a dense adjacency matrix)
  • embedding_dim (int): the embedding dimensionality of each node
  • walk_length (int): the length of each random walk
  • context_size (int): the window size used when sampling positive pairs from a walk
  • walks_per_node (int, optional): the number of walks started from each node
  • p (float, optional): the return parameter; a larger p makes a walk less likely to immediately revisit the previous node
  • q (float, optional): the in-out parameter; q > 1 biases walks toward BFS-like local exploration, q < 1 toward DFS-like outward exploration (see the sketch after this list)
  • num_negative_samples (int, optional): the number of negative samples drawn per positive sample
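
The example below uses the default p = q = 1 (unbiased walks). As a minimal sketch of biasing the walks, here is the same constructor call with explicit p and q values; it reuses the data and device objects set up in the full example that follows, and the p/q values are illustrative rather than tuned:

# A minimal sketch of biased walks; the p/q values here are illustrative, not tuned.
# q < 1 makes the walk more likely to move away from the previous node (DFS-like);
# q > 1 would keep it exploring the local neighborhood (BFS-like).
biased_model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                        context_size=10, walks_per_node=10, num_negative_samples=1,
                        p=1.0, q=0.5, sparse=True).to(device)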

The full code is as follows:

import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec

# Load the Cora citation dataset (the root path below is the author's local directory)
dataset = Planetoid(root='G:/torch_geometric_datasets', name='Cora')
data = dataset[0]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# p and q default to 1.0, i.e. unbiased random walks (equivalent to DeepWalk)
model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                 context_size=10, walks_per_node=10, num_negative_samples=1,
                 sparse=True).to(device)
# each batch yields positive random walks and the corresponding negative samples
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)

# In older PyTorch versions, torch.optim.SparseAdam(model.parameters(), lr=0.01) works
# directly; newer versions require wrapping the parameters in a list (this post uses PyTorch 1.7.1)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def test():
    model.eval()
    z = model()  # embeddings of all nodes
    # train a logistic-regression classifier on the train_mask embeddings
    # and evaluate it on the test_mask embeddings
    acc = model.test(z[data.train_mask], data.y[data.train_mask],
                     z[data.test_mask], data.y[data.test_mask], max_iter=150)
    return acc


for epoch in range(1, 101):
    loss = train()
    acc = test()
    print(f'Epoch:{epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')

@torch.no_grad()
def plot_points(colors):
    model.eval()
    z = model(torch.arange(data.num_nodes, device=device))  # embeddings of all nodes
    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())  # reduce to 2D with t-SNE
    y = data.y.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(dataset.num_classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])
    plt.axis('off')
    plt.show()


colors = ['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700']
plot_points(colors)
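
As a side note (a sketch not in the original post), the trained embeddings can also be extracted and saved for reuse by a downstream classifier; the filename here is illustrative:

with torch.no_grad():
    model.eval()
    emb = model().cpu()  # all node embeddings, shape [num_nodes, 128]
torch.save(emb, 'node2vec_cora_embeddings.pt')  # illustrative filename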

The training output is as follows:

Epoch:01, Loss: 8.0661, Acc: 0.1570
Epoch:02, Loss: 6.0309, Acc: 0.1800
Epoch:03, Loss: 4.9328, Acc: 0.2050
Epoch:04, Loss: 4.1206, Acc: 0.2400
Epoch:05, Loss: 3.4587, Acc: 0.2760
Epoch:06, Loss: 2.9389, Acc: 0.2950
Epoch:07, Loss: 2.5340, Acc: 0.3220
Epoch:08, Loss: 2.2042, Acc: 0.3410
Epoch:09, Loss: 1.9404, Acc: 0.3700
Epoch:10, Loss: 1.7295, Acc: 0.4050
Epoch:11, Loss: 1.5594, Acc: 0.4340
Epoch:12, Loss: 1.4231, Acc: 0.4660
Epoch:13, Loss: 1.3143, Acc: 0.4850
Epoch:14, Loss: 1.2242, Acc: 0.5100
Epoch:15, Loss: 1.1539, Acc: 0.5310
Epoch:16, Loss: 1.0997, Acc: 0.5560
Epoch:17, Loss: 1.0559, Acc: 0.5760
Epoch:18, Loss: 1.0199, Acc: 0.6020
Epoch:19, Loss: 0.9921, Acc: 0.6120
Epoch:20, Loss: 0.9671, Acc: 0.6190
Epoch:21, Loss: 0.9487, Acc: 0.6300
Epoch:22, Loss: 0.9335, Acc: 0.6390
Epoch:23, Loss: 0.9203, Acc: 0.6480
Epoch:24, Loss: 0.9106, Acc: 0.6580
Epoch:25, Loss: 0.8994, Acc: 0.6630
Epoch:26, Loss: 0.8924, Acc: 0.6600
Epoch:27, Loss: 0.8858, Acc: 0.6610
Epoch:28, Loss: 0.8792, Acc: 0.6670
Epoch:29, Loss: 0.8731, Acc: 0.6800
Epoch:30, Loss: 0.8697, Acc: 0.6830
Epoch:31, Loss: 0.8652, Acc: 0.6850
Epoch:32, Loss: 0.8618, Acc: 0.6840
Epoch:33, Loss: 0.8586, Acc: 0.6920
Epoch:34, Loss: 0.8550, Acc: 0.6900
Epoch:35, Loss: 0.8523, Acc: 0.6820
Epoch:36, Loss: 0.8507, Acc: 0.6800
Epoch:37, Loss: 0.8483, Acc: 0.6870
Epoch:38, Loss: 0.8469, Acc: 0.6930
Epoch:39, Loss: 0.8449, Acc: 0.6950
Epoch:40, Loss: 0.8433, Acc: 0.6920
Epoch:41, Loss: 0.8422, Acc: 0.6980
Epoch:42, Loss: 0.8398, Acc: 0.6960
Epoch:43, Loss: 0.8401, Acc: 0.6930
Epoch:44, Loss: 0.8374, Acc: 0.6930
Epoch:45, Loss: 0.8377, Acc: 0.6990
Epoch:46, Loss: 0.8363, Acc: 0.6970
Epoch:47, Loss: 0.8354, Acc: 0.7060
Epoch:48, Loss: 0.8339, Acc: 0.7130
Epoch:49, Loss: 0.8333, Acc: 0.7060
Epoch:50, Loss: 0.8340, Acc: 0.7090
Epoch:51, Loss: 0.8332, Acc: 0.7090
Epoch:52, Loss: 0.8325, Acc: 0.7090
Epoch:53, Loss: 0.8321, Acc: 0.7070
Epoch:54, Loss: 0.8316, Acc: 0.7160
Epoch:55, Loss: 0.8317, Acc: 0.7100
Epoch:56, Loss: 0.8297, Acc: 0.7130
Epoch:57, Loss: 0.8309, Acc: 0.7140
Epoch:58, Loss: 0.8296, Acc: 0.7230
Epoch:59, Loss: 0.8296, Acc: 0.7230
Epoch:60, Loss: 0.8276, Acc: 0.7190
Epoch:61, Loss: 0.8287, Acc: 0.7120
Epoch:62, Loss: 0.8294, Acc: 0.7120
Epoch:63, Loss: 0.8272, Acc: 0.7050
Epoch:64, Loss: 0.8286, Acc: 0.7040
Epoch:65, Loss: 0.8283, Acc: 0.7090
Epoch:66, Loss: 0.8278, Acc: 0.7110
Epoch:67, Loss: 0.8274, Acc: 0.7140
Epoch:68, Loss: 0.8283, Acc: 0.7190
Epoch:69, Loss: 0.8269, Acc: 0.7160
Epoch:70, Loss: 0.8271, Acc: 0.7210
Epoch:71, Loss: 0.8260, Acc: 0.7190
Epoch:72, Loss: 0.8273, Acc: 0.7130
Epoch:73, Loss: 0.8252, Acc: 0.7150
Epoch:74, Loss: 0.8264, Acc: 0.7120
Epoch:75, Loss: 0.8250, Acc: 0.7160
Epoch:76, Loss: 0.8253, Acc: 0.7190
Epoch:77, Loss: 0.8244, Acc: 0.7220
Epoch:78, Loss: 0.8263, Acc: 0.7220
Epoch:79, Loss: 0.8271, Acc: 0.7180
Epoch:80, Loss: 0.8253, Acc: 0.7110
Epoch:81, Loss: 0.8260, Acc: 0.7080
Epoch:82, Loss: 0.8246, Acc: 0.7140
Epoch:83, Loss: 0.8256, Acc: 0.7170
Epoch:84, Loss: 0.8257, Acc: 0.7210
Epoch:85, Loss: 0.8256, Acc: 0.7190
Epoch:86, Loss: 0.8244, Acc: 0.7170
Epoch:87, Loss: 0.8254, Acc: 0.7240
Epoch:88, Loss: 0.8249, Acc: 0.7170
Epoch:89, Loss: 0.8252, Acc: 0.7160
Epoch:90, Loss: 0.8243, Acc: 0.7010
Epoch:91, Loss: 0.8254, Acc: 0.7050
Epoch:92, Loss: 0.8249, Acc: 0.7030
Epoch:93, Loss: 0.8249, Acc: 0.7110
Epoch:94, Loss: 0.8233, Acc: 0.6990
Epoch:95, Loss: 0.8243, Acc: 0.6990
Epoch:96, Loss: 0.8248, Acc: 0.7140
Epoch:97, Loss: 0.8240, Acc: 0.7090
Epoch:98, Loss: 0.8247, Acc: 0.7100
Epoch:99, Loss: 0.8255, Acc: 0.7060
Epoch:100, Loss: 0.8242, Acc: 0.7160

[Figure: t-SNE visualization of the learned node embeddings, colored by class]
From the output we can see that the training loss keeps decreasing in the later epochs while the test accuracy stops improving and even dips, which suggests some overfitting.
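
A simple mitigation, sketched below under the same setup (this is not the original author's code), is to keep the best-performing embedding snapshot instead of the final one; in practice the selection should be done on data.val_mask rather than the test set to avoid leakage:

best_acc, best_emb = 0.0, None
for epoch in range(1, 101):
    loss = train()
    acc = test()  # in practice, select on data.val_mask instead to avoid test-set leakage
    if acc > best_acc:
        best_acc = acc
        best_emb = model().detach().cpu().clone()  # snapshot the current embeddings
    print(f'Epoch:{epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')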
