pytorch读取自己的数据集(分类文件夹加载)—ImageFolder()

马上本科毕业了,毕业设计内容和图像深度学习有关,数据集在网上找了很久才找到,找到的数据集按类别存放在相应文件夹、没有标签文件。我不知道怎么读取,在CSDN上找了好久,只有很少的文章里提到了文件夹读取,但是也没有详细说明,大多数都是用标签读取的。
而且CSDN上将深度学习入门的博客大都直接用MNIST等一些可以直接使用的数据集,很少讲用自己的数据集的。
现在我的问题解决了,自己写一篇来帮助一些和我一样遇到问题的人。

  • 首先,把图像数据集放在你创建的python文件中,我这里的maize就是我的图像数据集
    在这里插入图片描述
  • 数据集文件里按train、valid和test分好,我的每个里面的分为(0、1、2、3)4个类别。
    在这里插入图片描述
    在这里插入图片描述
  • 文件夹读取代码,运用ImageFolder()和DataLoader()
from __future__ import print_function, division
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

valid_transform=transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

train_dataset =torchvision.datasets.ImageFolder(root='maize/train',transform=train_transform)
train_loader =DataLoader(train_dataset,batch_size=1, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。

valid_dataset =torchvision.datasets.ImageFolder(root='maize/valid',transform=valid_transform)
valid_loader =DataLoader(valid_dataset,batch_size=1, shuffle=True,num_workers=0)

这样就把数据放到train_loader和valid_loader里面了,接下来就用train_loader和valid_loader来调用数据就行,后面就可以参考CSDN上深度学习的文章了。

补充另一种数据集加载的形式:

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'maize'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=1)
              for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

希望对大家有帮助!有没写清楚的地方可以留言或者私信我,我看到了都会回复的哦!

猜你喜欢

转载自blog.csdn.net/zzy_pphz/article/details/104711382