数据集加载——dataset和dataloader

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        return self.X[index], self.y[index]

MyDataset 是一个自定义的 PyTorch 数据集类,继承自 torch.utils.data.Dataset,用于加载并处理图像数据集。

在 PyTorch 中,数据集通常被表示为继承自 torch.utils.data.Dataset 的类,该类需要实现两个方法:__len__ 和 __getitem__。其中,__len__ 方法返回数据集中样本的数量,__getitem__ 方法按索引返回一个样本。这使得我们可以使用 PyTorch 的 DataLoader 来迭代数据集,并将数据批量地输入到神经网络中进行训练或预测。

在 MyDataset 类中,我们需要实现两个方法:

1.__init__ 方法:该方法初始化数据集,并加载数据集中的图像和标签。在该方法中,我们可以使用 Python 的文件操作或第三方库(如 PIL)来读取图像,并使用 NumPy 数组来存储它们。为了方便处理,我们可以将图像数据转换为 PyTorch 张量,并将标签转换为整数。

2.__getitem__ 方法:该方法根据给定的索引返回一个样本,其中包括图像和对应的标签。在该方法中,我们需要根据索引从图像数据集和标签数据集中获取对应的图像和标签,并对它们进行预处理(如归一化、调整大小等)。最后,我们将预处理后的图像和标签返回为 PyTorch 张量。

然后,我们可以使用 DataLoader 类来创建批次数据。以下是一个简单的例子:

dataset = MyDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

"dataset" 包含了训练数据,由特征向量 X_train 和相应的目标标签 y_train 组成。

DataLoader 中的 batch_size 参数指定了每个小批量中包含的样本数量。在这个例子中,每个小批量包含 32 个样本。

shuffle 参数设置为 True,这意味着在将样本划分为小批量之前,会随机地对数据集中的样本进行洗牌。这是一种常见的技术,用于防止模型过度拟合样本在数据集中的顺序。

猜你喜欢

转载自blog.csdn.net/weixin_50752408/article/details/129652747
今日推荐