深度学习:数据加载

0.前言

框架:pytorch

1.直接从网上下载数据集,用于模型的测试

    tsf = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.FashionMNIST(root='./data/fashionmnist/train',train=True,transform=tsf,download=True)
    test_data = datasets.FashionMNIST(root='./data/fashionmnist/test',train=False,transform=tsf,download=True)
    train_iter = DataLoader(train_data,batch_size=32,shuffle=True)
    test_iter = DataLoader(test_data,batch_size=32,shuffle=False)

2.从本地文件夹中加载

    tsf = transforms.Compose([transforms.Resize((28, 28)),
                              transforms.ToTensor()])
    train_data = datasets.ImageFolder(root=r'image\dogcat\dogcat\train',transform=tsf)
    train_iter = DataLoader(train_data,batch_size=32,shuffle=True)

本地文件夹的目录结构
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_42911863/article/details/126256390