PyTorch Geometric(PyG)解读、快速开始-简单易懂

PyG官方github

https://github.com/rusty1s/pytorch_geometric

torch_geometric.data

创建Data对象

这个模块包含了一个叫Data的类,可以创建Data对象
创建只需要:
节点的属性/特征(the attributes/features associated with each node, node features)
邻接/边连接信息(the connectivity/adjacency of each node, edge index)

假设现在有一个图包含三个节点,每个节点用一个特征向量表示,特征向量分别为f1、f2、f3

​x = torch.tensor([f1,f2,f3], dtype=torch.float) #节点的特征向量构成的特征矩阵
y = torch.tensor([0,1,0], dtype=torch.float)#每个节点归属的类别,这里三个节点分别归属于0,1,0类

边集可以被表示为:
边集以COCO格式存储
边集矩阵大小为2*E,E的大小就是有向边的总条数
矩阵的第一行是源节点的标号,第二行是目标节点的标号

​edge_index = torch.tensor([[0,1,2,0,3],
                          [1,0,1,3,2]],dtype=torch.long)

此处存储的边的顺序并不重要
边上的权值(可选参数,非必要):

edge_attr (Tensor, optional): Edge weights or multi-dimensional
    edge features. (default: :obj:`None`)

创建Data对象的完整示例:

import torch

from torch_geometric.data import Data

x = torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtype=torch.float)

y = torch.tensor([[0,2,1,0,3],[3,1,0,1,2]],dtype=torch.long)

​edge_index = torch.tensor([[0,1,2,0,3],
                          [1,0,1,3,2]],dtype=torch,long)

data = Data(x=x,y=y,edge_index=edge_index)

快速开始

有了data对象就可以快速开始了,PyG官方提供了许多图神经网络算法的接口
例如
在这里插入图片描述
可根据需要快速开始,示例

from torch_geometric.nn import GCNConv
in_channels=10
out_channels=5
#in_channels (int) – Size of each input sample.
#out_channels (int) – Size of each output sample.
conv1 = GCNConv(_channels, out_channels, cached=True)
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
out=conv1(x, edge_index, edge_weight)                             

out即为使用GCNConv卷积之后的结果

创建Dataset

官方关于创建Dataset的指南

https://github.com/rusty1s/pytorch_geometric/blob/a01dc15d5a879e0054f81f611a0dfb2a68ee9424/docs/source/notes/create_dataset.rst

PyG提供两种不同的数据集类:

1·InMemoryDataset
2·Dataset

可以理解为第一种数据集较小,在内存中可存下。第二种数据集较大,首先介绍第一种也就是InMemoryDataset

Raw_file_names()

它返回一个包含没有处理的数据的名字的list

Processed_file_names()

返回一个包含所有处理过的数据的list。在调用process()这个函数后,通常返回的list只有一个元素,它只保存已经处理过的数据的名字。

Download()

这个函数下载数据到你正在工作的目录中,你可以在self.raw_dir中指定。如果你不需要下载数据,你可以在这函数中写一个pass

Process()

这是Dataset中最重要的函数。你需要整合你的数据成一个包含data的list。然后调用 self.collate()去计算将用DataLodadr的片段
官方示例

import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

持续更新PyG相关内容,欢迎关注、留言

原创文章 32 获赞 4 访问量 7644

猜你喜欢

转载自blog.csdn.net/yrwang_xd/article/details/105349722
今日推荐