PyTorch Geometric的Mini-batches

官方文档 链接

加载ENZYMES数据集

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader


dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

loader = DataLoader(dataset, batch_size=4, shuffle=True)

ENZYMES数据集

batch

获取一个batch

batch = loader.__iter__().next()
print(batch)
# Batch(batch=[169], edge_index=[2, 556], ptr=[5], x=[169, 21], y=[4])

由于batch_size=4,所以batch中有4个图。batch的属性如图所示:
在这里插入图片描述

batch.keys
# ['x', 'edge_index', 'y', 'batch', 'ptr']

batch[0].keys
# ['x', 'edge_index', 'y']

取出单个数据

for i in range(batch.num_graphs):
    print(batch[i])
"""
Data(edge_index=[2, 178], x=[50, 21], y=[1])
Data(edge_index=[2, 114], x=[30, 21], y=[1])
Data(edge_index=[2, 160], x=[60, 21], y=[1])
Data(edge_index=[2, 104], x=[29, 21], y=[1])
"""

ptr属性

注意ptr这个属性,如果要把batch中的4个图取出来需要这个属性。
在这里插入图片描述

  • batch[0]就是[0:50] 50-0=50
  • batch[1]就是[50:80] 80-50=30
  • batch[2]就是[80:140] 140-80=60
  • batch[3]就是[140:169] 169-140=29

batch属性

输出batch属性查看一下
在这里插入图片描述
发现连续50个0,30个1,60个2,29个3

batch是怎么区分数据包括哪些的

batch.__slices__
""
{
    
    'y': [0, 1, 2, 3, 4], 
'x': [0, 50, 80, 140, 169], 
'edge_index': [0, 178, 292, 452, 556]}
""

获取batch[0]的时候,根据batch.__slices__

  • batch[0]['y'] = batch['y'][ batch.__slices__['y'][0]:batch.__slices__['y'][0+1] ]
  • batch[0]['x'] = batch['x'][ batch.__slices__['x'][0]:batch.__slices__['x'][0+1] ]
  • batch[0]['edge_index'] = batch['edge_index'][ batch.__slices__['edge_index'][0]:batch.__slices__['edge_index'][0+1] ]

获取batch[1]batch[2]、… 、batch[n]的时候,只用将 0 0 0改为相应的下标即可

PyTorch Geometric(PYG)-实现小批量data类中__inc__与__cat_dim__的含义与作用

https://blog.csdn.net/qq_41795143/article/details/114281387

猜你喜欢

转载自blog.csdn.net/qq_37252519/article/details/119357519
今日推荐