Desde el punto de vista del código, explicación en profundidad de la serie de notas de la red neuronal gráfica (2)

Prefacio

Esta sección de las notas se centra principalmente en la herencia InMemoryDataset, cargando todos los datos en la memoria a la vez, este tipo de conjunto de datos generalmente no es muy grande, por lo que se carga directamente a la vez.

Construye el conjunto de datos

1 、 Conjunto de datos

pytorch geometricHay dos tipos de construcción de conjuntos de datos:
1. Herencia InMemoryDataset, carga todos los datos en la memoria a la vez.
2. Herencia Dataset, carga en la memoria por etapas.

En el Datasetmétodo de inicialización personalizado , ingrese la ruta de almacenamiento de datos y pytorch geometricluego divida 2 carpetas bajo esta ruta:
1. raw_dir: la ruta para almacenar los datos sin procesar (generalmente en formato csv, mat, etc.)
2. procesado_dir: Almacene los datos procesados ​​(generalmente en formato pt, implementado por el método de proceso),
pero en Pytorch, en realidad no hay dos carpetas

Mira los documentos oficiales:

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

Inserte la descripción de la imagen aquí
La InMemoryDatasetfunción se introduce en la segunda línea del código de muestra . Primero, veamos el uso de esta función y sus parámetros.

2. Interpretación de InMemoryDataset

Inserte la descripción de la imagen aquí

  • rootEs el directorio raíz para el almacenamiento de conjuntos de datos.
  • tansformLo pre_transformmismo y diferente, lo mismo es que se usa para aceptar los datos y volver a la versión convertida de los datos; la diferencia es que tansformla conversión antes de cada visita pre_transformes convertir antes de guardar en disco
  • pre_filter Es una función que se utiliza para aceptar datos y devolver un valor booleano para indicar si el objeto de datos debe almacenarse en el conjunto de datos final.

3. Ejemplos de documentos oficiales

Regrese y continúe mirando el código. He integrado las instrucciones en los comentarios del código. Además, hay algunos lugares donde la explicación en el video no es muy clara. Combiné el artículo Hands-on Graph Neural Networks con PyTorch & PyTorch Geometric para agregar algunos comentarios y a mí mismo Comprensión.

# 官方代码 https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

import torch
from torch_geometric.data import InMemoryDataset # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html  CLASS InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None): # 初始化函数
        super(MyOwnDataset, self).__init__(root, transform, pre_transform) # super用于说明MyOwnDataset继承InMemoryDataset初始化结果
        self.data, self.slices = torch.load(self.processed_paths[0]) # 详见说明1

    @property # 修饰方法,使方法可以像属性一样访问(保护变量/只读函数转变)详见说明2
    def raw_file_names(self): # 返回一个包含没有处理的数据的名字的list
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self): # 返回一个包含所有处理过的数据的list
        return ['data.pt']

    def download(self): # 下载数据集函数,不需要的话直接填充pass
        # Download to `self.raw_dir`.

    # 整合你的数据成一个包含data的list,然后调用 self.collate()去计算将用于 DataLodadr 的片段
    def process(self):
        # Read data into huge `Data` list.
        data_list = [...] # 创建并读取了数据的列表

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)] # 判断数据对象是否应该保存

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list] # 保存到磁盘之前进行转化

        data, slices = self.collate(data_list)# 将数据对象的python列表整理为内部存储格式 torch_geometric.data.InMemoryDataset
        torch.save((data, slices), self.processed_paths[0])

1. Descripción 1:

Esta parte se refiere al conjunto de datos de creación propia de pytorch_geometric
. El conjunto de datos debe definirse datay slices:

  • dataSe refiere a un conjunto de datos de gráficos construido con un pytorch_geometrictipo de datos definido Data;
  • slicesSe refiere a rodajas, es decir, diferentes graphdivisiones en el conjunto de datos . Por ejemplo, se slices[‘x’]=[0,75,150]refiere a la división del conjunto de datos de acuerdo con 75 nodos, un total de tres gráficos, slices[‘y’]y slices['edge_index ']así sucesivamente. slicesSe utiliza para distinguir entre diferentes funciones graphe implementaciones shuffle. Es de destacar que slicesla necesidad intde tensortipo, de lo contrario DataLoader, no admite operaciones de corte.

2. Descripción 2:

Esta parte se refiere a la introducción y uso de python @property.
Creo que la explicación en este artículo es más fácil de entender que el análisis encontrado por el maestro de up. Ya es muy conciso. No lo extraeré de mi artículo. Puedes verlo con atención.

Ejemplo de código de Amazon

# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets/amazon.py
import torch
from torch_geometric.data import InMemoryDataset, download_url # download_url为了下载数据
from torch_geometric.io import read_npz 


class Amazon(InMemoryDataset):
    r"""The Amazon Computers and Amazon Photo networks from the
    `"Pitfalls of Graph Neural Network Evaluation"
    <https://arxiv.org/abs/1811.05868>`_ paper.
    Nodes represent goods and edges represent that two goods are frequently
    bought together.
    Given product reviews as bag-of-words node features, the task is to
    map goods to their respective product category.
    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Computers"`,
            :obj:`"Photo"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower() # lower将字符串所有大小转小写
        assert self.name in ['computers', 'photo'] # 利用断言判断 name 值的范围是不是在 computers/photo 范围内
        super(Amazon, self).__init__(root, transform, pre_transform) # 继承初始化值
        self.data, self.slices = torch.load(self.processed_paths[0])  

    @property
    def raw_file_names(self):
        return 'amazon_electronics_{}.npz'.format(self.name)

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        download_url(self.url + self.raw_file_names, self.raw_dir)

    def process(self):
        data = read_npz(self.raw_paths[0]) # 读取 npz 格式数据集
        data = data if self.pre_transform is None else self.pre_transform(data)
        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())

Es básicamente la misma que la estructura dada por el documento del sitio web oficial, y algunos detalles de procesamiento son ligeramente diferentes. De hecho, cuando vi este lugar, ya estaba un poco confundido, porque por ejemplo, el código self.processed_paths[0]no está definido y asignado, por qué se puede llamar directamente.

Esta parte de las dudas será respondida más tarde para luego volver y seguir cambiando las notas

Supongo que te gusta

Origin blog.csdn.net/wy_97/article/details/108547022
Recomendado
Clasificación