PyTorch 自定义 Dataset

参考

一个例子

import torch
from torch.utils import data

class MyDataset(data.Dataset):
    """Toy dataset holding an 8x2 tensor of random samples.

    Indexing yields the selected sample together with the index that was
    requested, so the index can stand in for a label.
    """

    def __init__(self):
        super().__init__()
        # Samples are drawn once, at construction time, via torch.randn.
        self.data = torch.randn(8, 2)

    def __len__(self):
        """Return the number of samples, i.e. the first dimension of ``self.data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the pair ``(self.data[index], index)``."""
        sample = self.data[index]
        return sample, index

# Build the dataset and print the full tensor of samples it holds.
data_set = MyDataset()
print(data_set.data)

输出
tensor([[-1.3907, -0.0916],
        [-0.4626, -1.3323],
        [ 1.4242, -2.1718],
        [ 1.5850,  0.3320],
        [-1.0804,  0.3884],
        [ 0.6567, -0.1234],
        [ 1.6721, -0.7327],
        [-1.9595, -0.3512]])

# Wrap the dataset in a DataLoader that serves it in its original order,
# four samples per batch.
data_loader = data.DataLoader(data_set,
                              batch_size=4,
                              shuffle=False)
print(len(data_set))
# Each batch is the (sample, index) pair from MyDataset.__getitem__ after
# collation: `number` holds the batched samples and `labels` the matching
# indices. Both are printed so the index tensors appear in the output too.
for number, labels in data_loader:
    print(number)
    print(labels)

输出
8
tensor([[-1.3907, -0.0916],
        [-0.4626, -1.3323],
        [ 1.4242, -2.1718],
        [ 1.5850,  0.3320]])
tensor([0, 1, 2, 3])
tensor([[-1.0804,  0.3884],
        [ 0.6567, -0.1234],
        [ 1.6721, -0.7327],
        [-1.9595, -0.3512]])
tensor([4, 5, 6, 7])

猜你喜欢

转载自www.cnblogs.com/huanxifan/p/12656900.html
今日推荐