[PyTorch Getting Started] Reading Data in PyTorch

Copyright notice: this is an original article by CSDN blogger Rosefun96. https://blog.csdn.net/rosefun96/article/details/87947590

Introduction

Lately I have been looking at semantic segmentation for images, which is quite interesting, and getting started with PyTorch at the same time. A key characteristic of PyTorch is that essentially every operation is wrapped in a class: it ships with many built-in classes, and you can also modify or subclass the official ones.

1 Data Import

For data import, PyTorch already provides several classes, namely Dataset, DataLoader, DataLoaderIter, and so on.
Below is one approach I use.
My data lives in data_dir, with each subfolder treated as one class.

import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

data_dir = '/Ryoma/data/'

transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.ToTensor()
])

# Each subfolder of data_dir is treated as one class
train_sets = datasets.ImageFolder(data_dir, transform)
train_loader = torch.utils.data.DataLoader(train_sets, batch_size=10,
                                           shuffle=True, num_workers=4)
print(train_loader)
inputs, classes = next(iter(train_loader))

# Visualize a few images
def imshow(inp, title=None):
    """Imshow for a batch of Tensors."""
    print(inp.shape)
    inp = inp[0]                            # take the first image of the batch
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
#     mean = np.array([0.485, 0.456, 0.406])
#     std = np.array([0.229, 0.224, 0.225])
#     inp = std * inp + mean
    plt.imshow(inp)
    if title is not None:
        plt.title(title)

imshow(inputs)
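
ImageFolder derives the class labels from the subfolder names. To check how folder names map to the integer labels the loader returns, you can inspect the classes and class_to_idx attributes; a minimal sketch, assuming the train_sets object above (the folder names in the comments are made up):

# Inspect the class mapping that ImageFolder built from the subfolder names
print(train_sets.classes)       # e.g. ['cats', 'dogs'] -- hypothetical folder names
print(train_sets.class_to_idx)  # e.g. {'cats': 0, 'dogs': 1}

# The integer labels returned by the DataLoader index into this list
print([train_sets.classes[c.item()] for c in classes[:5]])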

Splitting the dataset
If you need to split the dataset into training and validation sets, you can use the following approach:

import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler

# train_dataset, valid_dataset, valid_size, shuffle, random_seed,
# batch_size, num_workers and pin_memory are assumed to be defined already
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))

if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)
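
As an alternative to the sampler-based split above, newer PyTorch versions also provide torch.utils.data.random_split, which splits a dataset by lengths. A minimal sketch, reusing train_sets from section 1 and an example 20% validation fraction:

import torch
from torch.utils.data import DataLoader, random_split

valid_size = 0.2  # example validation fraction
num_total = len(train_sets)
num_valid = int(valid_size * num_total)
num_train = num_total - num_valid

# random_split returns two Subset objects that can be wrapped in DataLoaders directly
train_subset, valid_subset = random_split(train_sets, [num_train, num_valid])

train_loader = DataLoader(train_subset, batch_size=10, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_subset, batch_size=10, shuffle=False, num_workers=4)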

2 DataLoader

Using DataLoader is the more efficient approach. First write a Dataset class that can read a single sample; then use DataLoader to read data in batches.

import torch
from torch.utils import data

class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, list_IDs, labels):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        X = torch.load('data/' + ID + '.pt')
        y = self.labels[ID]

        return X, y
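
As a quick sanity check, the class can be exercised with a couple of dummy tensors saved to disk; a minimal sketch in which the IDs, labels, tensor shapes and the data/ directory are all made up for illustration:

import os
import torch

os.makedirs('data', exist_ok=True)

# Write two fake samples matching the 'data/<ID>.pt' layout expected by __getitem__
list_IDs = ['id-1', 'id-2']
labels = {'id-1': 0, 'id-2': 1}
for ID in list_IDs:
    torch.save(torch.randn(3, 32, 32), 'data/' + ID + '.pt')

dataset = Dataset(list_IDs, labels)
X, y = dataset[0]
print(len(dataset), X.shape, y)  # 2 torch.Size([3, 32, 32]) 0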

Then, in the training script:

import torch
from torch.backends import cudnn
from torch.utils import data

from my_classes import Dataset


# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
cudnn.benchmark = True

# Parameters
params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 6}
max_epochs = 100

# Datasets
partition = {'train': [...], 'validation': [...]}  # fill in your sample IDs
labels = {}                                        # fill in your labels, keyed by ID

# Generators
training_set = Dataset(partition['train'], labels)
training_generator = data.DataLoader(training_set, **params)

validation_set = Dataset(partition['validation'], labels)
validation_generator = data.DataLoader(validation_set, **params)

# Loop over epochs
for epoch in range(max_epochs):
    # Training
    for local_batch, local_labels in training_generator:
        # Transfer to GPU
        local_batch, local_labels = local_batch.to(device), local_labels.to(device)

        # Model computations
        [...]

    # Validation
    with torch.set_grad_enabled(False):
        for local_batch, local_labels in validation_generator:
            # Transfer to GPU
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)

            # Model computations
            [...]
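
The [...] placeholders stand for the model-specific computations. Purely as an illustration, with a made-up model, loss and optimizer (none of these appear in the original), a typical training step might look like the following:

import torch.nn as nn
import torch.optim as optim

# Hypothetical model, loss and optimizer -- replace with your own
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# In place of [...] inside the training loop:
optimizer.zero_grad()
outputs = model(local_batch)
loss = criterion(outputs, local_labels)
loss.backward()
optimizer.step()

# In place of [...] inside the validation loop (no backward pass is needed,
# since gradients are disabled by torch.set_grad_enabled(False)):
outputs = model(local_batch)
val_loss = criterion(outputs, local_labels)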

References:
1. Zhihu: reading data in PyTorch
2. CSDN: reading data in PyTorch
3. GitHub: RuntimeError: Found 0 images in subfolders
4. PyTorch docs: TORCHVISION.DATASETS
5. GitHub: dataset splitting methods
6. Stanford: loading data in parallel
