制作pytorch数据集

方法1、自定义的dataset类

需要实现必要的魔法方法:

  • __init__魔法方法里面进行读取数据文件
  • __getitem__魔法方法进行支持下标访问
  • __len__魔法方法返回自定义数据集的大小,方便后期遍历

面已经定义好了抽象数据,只需给出自己的dataset和idxs(数据的索引列表))

from torch.utils.data import DataLoader, Dataset
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]
 
    def __len__(self):
        return len(self.idxs)
 
    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.as_tensor(image), torch.as_tensor(label)
train_loader = DataLoader(DatasetSplit(train_dataset, client_idxs),
                          batch_size=args.local_bs, shuffle=True)

上面的train_dataset是你的数据集,client_idx是你的数据的索引列表,比如[1,2,345,33,54...........],数字代表数据在dataset中的位置。这样制作后的数据集就是client_idx索引的数据集。

 

方法2:直接使用torch.utils.data.TensorDataset()封装数据集

#划分数据集
import torch
import numpy as np
import torch.utils.data as Data
from sklearn.model_selection import train_test_split
x_train, x_test, y_train,y_test = train_test_split(feature, labels, test_size=0.25)
#制作pytorch识别的数据集
train_dataset = Data.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train))
test_dataset = Data.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test))
#制作可迭代的数据集
train_iter = Data.DataLoader(dataset = train_dataset,batch_size = batch_size,
                             shuffle = True, num_workers = 2)
test_iter = Data.DataLoader(dataset = test_dataset, batch_size= batch_size,
                           shuffle = True, num_workers = 2) 

 

 

 

 

Guess you like

Origin blog.csdn.net/stay_zezo/article/details/115596443