Dataset 基类
torch.utils.data.Dataset 为数据集的基类, 继承这个基类,我们能够非常快速的实现对数据的加载。
我们要实现自己加载数据的类,并继承于Dataset 这个类,重载类的成员函数
1、__1en__方法, 能够实现通过全局的len()方法获取其中的元素个数;
2、getitem 方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第 i条数据。
from torch.utils.data import Dataset, DataLoader
# 完成数据集类
class MyDataset(Dataset):
def __init__(self):
def __getitem__(self, index):
""" 必须实现,作用是:获取索引对应位置的一条数据 :param index: :return: """
return to_tensor(self.data[index])
def __len__(self):
""" 必须实现,作用是得到数据集的大小 :return: """
return len(self.data)
def to_tensor(data):
return torch.from_numpy(data)
使用Dataset 能够进行数据的读取,但是还需要如下实现:
批处理数据(Batching the data)
打乱数据(Shuffling the data)
使用多线程multiprocessing并行加载数据
定义好 Dataset 之后就可以用DataLoader进行加载。
DataLoader 调用一句话即可,dataset 指向 自定义的读取数据类。
data_loader = DataLoader(dataset=data, batch_size=2, shuffle=True, num_workers=2)
参数:
1、dataset:提前定义的dataset的实例;
2、batch_size:传入数据的batch大小,常常是32、64
3、shuffle:bool类型,打乱数据;
4、num_workers:加载数据的线程数。
5、drop_last:bool类型,为真,表示最后的数据不足一个batch,就删掉
迭代遍历:
for step, (batch_x, batch_y) in enumerate(data_loader):
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))