コードの観点から、グラフニューラルネットワークの一連のノートの詳細な説明(2)

序文

ノートのこのセクションでは、主に継承に焦点を当てInMemoryDataset、一度にすべてのデータをメモリにロードします。この種のデータセットは通常それほど大きくないため、一度に直接ロードされます

データセットを作成する

1、データセット

pytorch geometricデータセットの構築には2つのタイプがあり
ます。1.継承InMemoryDataset、すべてのデータを一度にメモリにロード
2.継承Dataset、段階的にメモリにロード

カスタムDataset初期化メソッドで、データストレージパスを入力し、pytorch geometricこのパスの下の2つのフォルダーを分割します
。1. raw_dir:生データを保存するパス(通常はcsv、matなどの形式)
2 . processing_dir :処理されたデータ(通常はpt形式、プロセスメソッドによって実装されます)
を格納しますが、pytorchでは実際には2つのフォルダーはありません

公式文書を見てください:

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

ここに画像の説明を挿入
InMemoryDataset関数はサンプルコードの2行目に導入されています。最初に、この関数とそのパラメータの使用方法を見てみましょう

2. InMemoryDatasetの解釈

ここに画像の説明を挿入

  • rootデータセットストレージのルートディレクトリです
  • tansformpre_transform同じと異なる、同じことが、それはデータの変換後のバージョンへのデータとリターンを受け入れるために使用されていることであり、差があることをtansform各訪問前変換はpre_transform、ディスクに保存する前に変換することです
  • pre_filter データを受け入れ、データオブジェクトを最終的なデータセットに格納する必要があるかどうかを示すブール値を返すために使用される関数です

3.公式文書の例

戻ってコードを見てください。コードのコメントに手順を統合しました。さらに、ビデオの説明があまり明確でない場所がいくつかあります。ハンズオングラフニューラルネットワークとPyTorchおよびPyTorch Geometricを組み合わせて、コメントと自分自身追加しました理解。

# 官方代码 https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

import torch
from torch_geometric.data import InMemoryDataset # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html  CLASS InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None): # 初始化函数
        super(MyOwnDataset, self).__init__(root, transform, pre_transform) # super用于说明MyOwnDataset继承InMemoryDataset初始化结果
        self.data, self.slices = torch.load(self.processed_paths[0]) # 详见说明1

    @property # 修饰方法,使方法可以像属性一样访问(保护变量/只读函数转变)详见说明2
    def raw_file_names(self): # 返回一个包含没有处理的数据的名字的list
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self): # 返回一个包含所有处理过的数据的list
        return ['data.pt']

    def download(self): # 下载数据集函数,不需要的话直接填充pass
        # Download to `self.raw_dir`.

    # 整合你的数据成一个包含data的list,然后调用 self.collate()去计算将用于 DataLodadr 的片段
    def process(self):
        # Read data into huge `Data` list.
        data_list = [...] # 创建并读取了数据的列表

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)] # 判断数据对象是否应该保存

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list] # 保存到磁盘之前进行转化

        data, slices = self.collate(data_list)# 将数据对象的python列表整理为内部存储格式 torch_geometric.data.InMemoryDataset
        torch.save((data, slices), self.processed_paths[0])

1.説明1:

この部分はpytorch_geometric自作のデータセットを参照してい
ます。データセットは次のように定義しdata必要がありますslices

  • datapytorch_geometric定義されたデータ型でData構築されたグラフデータセットを指します。
  • slicesそれは、異なるスライスを指すgraphの分割データセット、例えば、それがslices[‘x’]=[0,75,150]指す区分75個のノード、3つのグラフの合計に応じてデータ・セットの、slices[‘y’]そしてslices['edge_index ']ように。slices異なる機能graphと実装を区別するために使用さshuffleます。slicesの必要性inttensorなければ、DataLoaderスライス操作をサポートしないことは注目に値します。

2.説明2:

この部分はpython @propertyの紹介と使用について言及してい
ますこの記事の説明は、アップマスターが見つけた分析よりも理解しやすいと思います。すでに非常に簡潔です。私の記事からは抜粋しません。慎重に見ることができます。

Amazonコード例

# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets/amazon.py
import torch
from torch_geometric.data import InMemoryDataset, download_url # download_url为了下载数据
from torch_geometric.io import read_npz 


class Amazon(InMemoryDataset):
    r"""The Amazon Computers and Amazon Photo networks from the
    `"Pitfalls of Graph Neural Network Evaluation"
    <https://arxiv.org/abs/1811.05868>`_ paper.
    Nodes represent goods and edges represent that two goods are frequently
    bought together.
    Given product reviews as bag-of-words node features, the task is to
    map goods to their respective product category.
    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Computers"`,
            :obj:`"Photo"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower() # lower将字符串所有大小转小写
        assert self.name in ['computers', 'photo'] # 利用断言判断 name 值的范围是不是在 computers/photo 范围内
        super(Amazon, self).__init__(root, transform, pre_transform) # 继承初始化值
        self.data, self.slices = torch.load(self.processed_paths[0])  

    @property
    def raw_file_names(self):
        return 'amazon_electronics_{}.npz'.format(self.name)

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        download_url(self.url + self.raw_file_names, self.raw_dir)

    def process(self):
        data = read_npz(self.raw_paths[0]) # 读取 npz 格式数据集
        data = data if self.pre_transform is None else self.pre_transform(data)
        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())

基本的には公式サイトドキュメントの構造と同じですが、処理内容が若干異なります。実際、この場所を見たとき、たとえばコードがself.processed_paths[0]定義されておらず割り当てられていないので、私はすでに少し混乱していました。なぜそれを直接呼び出すことができるのでしょうか。

疑問のこの部分は後で回答され、その後戻ってメモを変更し続けます

おすすめ

転載: blog.csdn.net/wy_97/article/details/108547022