[PyTorch Learning] Data Loading and Processing

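Preface

Adapted from the PyTorch tutorial DATA LOADING AND PROCESSING TUTORIAL.

This post covers dataset processing. It shows, class by class, how to write each component yourself, and at the end it presents the utilities that PyTorch (torchvision) already ships with, which save a lot of work.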

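Importing packages

The code in this post relies on the following imports (io.imread and transform.resize come from scikit-image):

import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms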

Dataset class

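torch.utils.data.Dataset is an abstract class representing a dataset.
A custom dataset should inherit from Dataset and override the following methods:

  • __len__, so that len(dataset) returns the size of the dataset
  • __getitem__, so that dataset[i] returns the i-th sample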
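A minimal sketch of that contract, assuming only that PyTorch is installed. The class SquaresDataset and its contents are hypothetical; they only show that len() and [] dispatch to the two overridden methods.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Called by len(dataset)
        return self.n

    def __getitem__(self, idx):
        # Called by dataset[idx]; valid indices are 0 .. len(dataset) - 1
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return torch.tensor(idx), torch.tensor(idx * idx)


dataset = SquaresDataset(5)
size = len(dataset)   # 5
x, y = dataset[3]     # tensor(3), tensor(9)
```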

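Example: a face-landmarks dataset. Each row of the CSV file holds an image file name followed by that image's landmark coordinates; each sample is returned as a dict {'image': image, 'landmarks': landmarks}, and an optional transform is applied to it: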

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

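    def __getitem__(self, idx):
        # First column: image file name
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Remaining columns: landmark coordinates x1, y1, x2, y2, ...
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        # One (x, y) row per landmark
        landmarks = landmarks.astype('float').reshape(-1, 2)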
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
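__getitem__ assumes each CSV row is an image file name followed by flattened coordinates x1, y1, x2, y2, .... The sketch below reproduces only that parsing step on a small in-memory table, so it runs without the image files. The column names and values are made up for illustration; to_numpy() is used because as_matrix() was removed in pandas 1.0.

```python
from io import StringIO

import pandas as pd

# Hypothetical CSV in the layout the dataset expects:
# image_name, part_0_x, part_0_y, part_1_x, part_1_y, ...
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x,part_2_y
face_a.jpg,10,20,30,40,50,60
face_b.jpg,11,21,31,41,51,61
"""
landmarks_frame = pd.read_csv(StringIO(csv_text))

idx = 1
img_name = landmarks_frame.iloc[idx, 0]               # first column: file name
landmarks = landmarks_frame.iloc[idx, 1:].to_numpy()  # remaining columns, flattened
landmarks = landmarks.astype('float').reshape(-1, 2)  # one (x, y) row per point
```

The astype('float') step matters: a row taken from a table with a string column may come back with object dtype, and the cast turns it into a proper float array before reshaping.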

Transforms

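Most neural networks expect input images of the same size, so we need to write some preprocessing code. For example:

  • Rescale: rescale the image
  • RandomCrop: crop the image at a random position (a form of data augmentation)
  • ToTensor: convert a numpy image into a torch image (tensor)

Each transform is written as a class that defines __call__: its parameters are passed to the constructor once, and an instance can then be called like a function on every sample.
Sample code: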

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive; + 1 also allows a crop as large as the image
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
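The geometry in these transforms is easy to get wrong, so here is a NumPy-only sketch that re-implements just the size and offset arithmetic. The helper names rescaled_size and crop_sample are hypothetical, not part of the tutorial. It checks three points: an int output_size matches the shorter edge and keeps the aspect ratio; after a crop at (top, left) landmarks must shift by [left, top], because they are stored as (x, y) while arrays are indexed [row, column]; and since randint's upper bound is exclusive, h - new_h + 1 is needed for a crop as large as the image.

```python
import numpy as np


def rescaled_size(h, w, output_size):
    """Target (new_h, new_w), computed the same way as Rescale.__call__."""
    if isinstance(output_size, int):
        # Match the smaller edge to output_size and keep the aspect ratio
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    return int(new_h), int(new_w)


def crop_sample(image, landmarks, top, left, new_h, new_w):
    """Keep rows [top, top + new_h) and columns [left, left + new_w); shift landmarks."""
    cropped = image[top: top + new_h, left: left + new_w]
    # Landmarks are (x, y) = (column, row), so subtract [left, top]
    return cropped, landmarks - [left, top]


# 480 x 640 image with an int size of 256: the shorter edge (h) becomes 256
size_a = rescaled_size(480, 640, 256)   # (256, 341)

# Mark one pixel at x=5, y=3 and record it as a landmark
image = np.zeros((6, 8), dtype=np.uint8)
image[3, 5] = 255
landmarks = np.array([[5.0, 3.0]])
cropped, shifted = crop_sample(image, landmarks, top=2, left=3, new_h=4, new_w=4)
x, y = shifted[0].astype(int)
# The shifted landmark still points at the marked pixel: cropped[y, x] == 255

# A 6-row crop of a 6-row image: randint(0, 0) would raise ValueError,
# randint(0, 1) always returns 0
top = np.random.randint(0, 6 - 6 + 1)
```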

Compose transforms

torchvision.transforms.Compose is a simple callable class that chains a list of transforms together.
For example:

composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

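Calling composed(sample) then applies Rescale(256) followed by RandomCrop(224) to a sample. Note that the argument is the sample dict, not a bare image, because each of these transforms reads both sample['image'] and sample['landmarks'].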
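Conceptually, Compose just calls each transform in list order, feeding the output of one into the next. A pure-Python sketch of that idea; SimpleCompose and the two toy transforms are hypothetical stand-ins, not the torchvision implementation.

```python
class SimpleCompose:
    """Hypothetical minimal equivalent of transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        # Apply each callable in list order
        for t in self.transforms:
            sample = t(sample)
        return sample


class AddOffset:
    """Toy callable transform: shift every landmark by a fixed amount."""

    def __init__(self, offset):
        self.offset = offset

    def __call__(self, sample):
        shifted = [(x + self.offset, y + self.offset) for x, y in sample['landmarks']]
        return {'image': sample['image'], 'landmarks': shifted}


class Scale:
    """Toy callable transform: multiply every landmark coordinate."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        scaled = [(x * self.factor, y * self.factor) for x, y in sample['landmarks']]
        return {'image': sample['image'], 'landmarks': scaled}


sample = {'image': None, 'landmarks': [(1, 2)]}
# Order matters: (1 + 1) * 10 = 20, whereas 1 * 10 + 1 = 11
a = SimpleCompose([AddOffset(1), Scale(10)])(sample)
b = SimpleCompose([Scale(10), AddOffset(1)])(sample)
```

The same ordering concern applies to the real pipeline: Rescale(256) runs before RandomCrop(224) so that the crop always fits inside the rescaled image.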

Iterating through the dataset

torch.utils.data.DataLoader adds the following on top of a dataset:

  • Batching the data (grouping several samples into one batch)
  • Shuffling the data
  • Loading the data in parallel using multiprocessing workers

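Usage, where transformed_dataset is a FaceLandmarksDataset created with transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]):
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)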
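To see what the loader yields, here is a sketch with a hypothetical in-memory dataset (RandomFaces) whose samples have the same layout ToTensor produces: a 3 x H x W image and an N x 2 landmark tensor. The default collate function stacks each dict field along a new leading batch dimension, and the last batch is smaller unless drop_last=True. num_workers=0 keeps this small example in the main process; when using worker processes on platforms with the spawn start method (Windows, macOS), the loading code should sit under an if __name__ == '__main__': guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomFaces(Dataset):
    """Hypothetical stand-in for a transformed FaceLandmarksDataset."""

    def __init__(self, n, size=8, n_points=5):
        self.n = n
        self.size = size
        self.n_points = n_points

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Same layout as after ToTensor: C x H x W image, N x 2 landmarks
        image = torch.rand(3, self.size, self.size)
        landmarks = torch.rand(self.n_points, 2) * self.size
        return {'image': image, 'landmarks': landmarks}


loader = DataLoader(RandomFaces(10), batch_size=4, shuffle=True, num_workers=0)

batch_shapes = []
for batch in loader:
    # Each dict field is stacked along dimension 0
    batch_shapes.append((tuple(batch['image'].shape), tuple(batch['landmarks'].shape)))
```

With 10 samples and batch_size=4 the loop sees batches of 4, 4 and 2 samples.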

torchvision

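In this post we have seen how to use Dataset, transforms and DataLoader.
The torchvision package already provides some common datasets and transforms, so for many tasks you do not need to write these classes yourself.
For example: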

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        # These are all built-in transforms from torchvision.transforms
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
# ImageFolder (from torchvision.datasets) loads images stored one folder per class
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)


Reposted from blog.csdn.net/crabstew/article/details/89012370