序文
ノートのこのセクションでは、主に継承に焦点を当てInMemoryDataset
、一度にすべてのデータをメモリにロードします。この種のデータセットは通常それほど大きくないため、一度に直接ロードされます
データセットを作成する
1、データセット
pytorch geometric
データセットの構築には2つのタイプがあり
ます。1.継承InMemoryDataset
、すべてのデータを一度にメモリにロード
2.継承Dataset
、段階的にメモリにロード
カスタムDataset
初期化メソッドで、データストレージパスを入力し、pytorch geometric
このパスの下の2つのフォルダーを分割します
。1. raw_dir:生データを保存するパス(通常はcsv、matなどの形式)
2 . processing_dir :処理されたデータ(通常はpt形式、プロセスメソッドによって実装されます)
を格納しますが、pytorchでは実際には2つのフォルダーはありません
公式文書を見てください:
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
InMemoryDataset
関数はサンプルコードの2行目に導入されています。最初に、この関数とそのパラメータの使用方法を見てみましょう
2. InMemoryDatasetの解釈
root
データセットストレージのルートディレクトリですtansform
pre_transform
同じと異なる、同じことが、それはデータの変換後のバージョンへのデータとリターンを受け入れるために使用されていることであり、差があることをtansform
各訪問前変換はpre_transform
、ディスクに保存する前に変換することですpre_filter
データを受け入れ、データオブジェクトを最終的なデータセットに格納する必要があるかどうかを示すブール値を返すために使用される関数です
3.公式文書の例
戻ってコードを見てください。コードのコメントに手順を統合しました。さらに、ビデオの説明があまり明確でない場所がいくつかあります。ハンズオングラフニューラルネットワークとPyTorchおよびPyTorch Geometricを組み合わせて、コメントと自分自身を追加しました理解。
# 官方代码 https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
import torch
from torch_geometric.data import InMemoryDataset # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html CLASS InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None): # 初始化函数
super(MyOwnDataset, self).__init__(root, transform, pre_transform) # super用于说明MyOwnDataset继承InMemoryDataset初始化结果
self.data, self.slices = torch.load(self.processed_paths[0]) # 详见说明1
@property # 修饰方法,使方法可以像属性一样访问(保护变量/只读函数转变)详见说明2
def raw_file_names(self): # 返回一个包含没有处理的数据的名字的list
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self): # 返回一个包含所有处理过的数据的list
return ['data.pt']
def download(self): # 下载数据集函数,不需要的话直接填充pass
# Download to `self.raw_dir`.
# 整合你的数据成一个包含data的list,然后调用 self.collate()去计算将用于 DataLodadr 的片段
def process(self):
# Read data into huge `Data` list.
data_list = [...] # 创建并读取了数据的列表
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)] # 判断数据对象是否应该保存
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list] # 保存到磁盘之前进行转化
data, slices = self.collate(data_list)# 将数据对象的python列表整理为内部存储格式 torch_geometric.data.InMemoryDataset
torch.save((data, slices), self.processed_paths[0])
1.説明1:
この部分はpytorch_geometric自作のデータセットを参照してい
ます。データセットは次のように定義しdata
、必要がありますslices
。
data
pytorch_geometric
定義されたデータ型でData
構築されたグラフデータセットを指します。slices
それは、異なるスライスを指すgraph
の分割データセット、例えば、それがslices[‘x’]=[0,75,150]
指す区分75個のノード、3つのグラフの合計に応じてデータ・セットの、slices[‘y’]
そしてslices['edge_index ']
ように。slices
異なる機能graph
と実装を区別するために使用されshuffle
ます。型slices
の必要性int
がtensor
なければ、DataLoader
スライス操作をサポートしないことは注目に値します。
2.説明2:
この部分はpython @propertyの紹介と使用について言及してい
ます。この記事の説明は、アップマスターが見つけた分析よりも理解しやすいと思います。すでに非常に簡潔です。私の記事からは抜粋しません。慎重に見ることができます。
Amazonコード例
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets/amazon.py
import torch
from torch_geometric.data import InMemoryDataset, download_url # download_url为了下载数据
from torch_geometric.io import read_npz
class Amazon(InMemoryDataset):
r"""The Amazon Computers and Amazon Photo networks from the
`"Pitfalls of Graph Neural Network Evaluation"
<https://arxiv.org/abs/1811.05868>`_ paper.
Nodes represent goods and edges represent that two goods are frequently
bought together.
Given product reviews as bag-of-words node features, the task is to
map goods to their respective product category.
Args:
root (string): Root directory where the dataset should be saved.
name (string): The name of the dataset (:obj:`"Computers"`,
:obj:`"Photo"`).
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
"""
url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'
def __init__(self, root, name, transform=None, pre_transform=None):
self.name = name.lower() # lower将字符串所有大小转小写
assert self.name in ['computers', 'photo'] # 利用断言判断 name 值的范围是不是在 computers/photo 范围内
super(Amazon, self).__init__(root, transform, pre_transform) # 继承初始化值
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return 'amazon_electronics_{}.npz'.format(self.name)
@property
def processed_file_names(self):
return 'data.pt'
def download(self):
download_url(self.url + self.raw_file_names, self.raw_dir)
def process(self):
data = read_npz(self.raw_paths[0]) # 读取 npz 格式数据集
data = data if self.pre_transform is None else self.pre_transform(data)
data, slices = self.collate([data])
torch.save((data, slices), self.processed_paths[0])
def __repr__(self):
return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())
基本的には公式サイトドキュメントの構造と同じですが、処理内容が若干異なります。実際、この場所を見たとき、たとえばコードがself.processed_paths[0]
定義されておらず割り当てられていないので、私はすでに少し混乱していました。なぜそれを直接呼び出すことができるのでしょうか。
疑問のこの部分は後で回答され、その後戻ってメモを変更し続けます