DATA LOADING AND PROCESSING TUTORIAL

在解决任何机器学习问题时,都需要花费大量的精力来准备数据。PyTorch提供了许多工具来简化数据加载,希望能使代码更具可读性。在本教程中,我们将看到如何加载和预处理/增强非平凡数据集中的数据。

要运行本教程,请确保安装了以下软件包:

  • scikit-image:用于图像io和转换
  • pandas:用于更容易的csv解析
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

我们要处理的数据集是面部姿态数据集。这意味着一张脸被这样注释

总共有68个不同的地标点被标注在每张脸上。

注意

从这里下载数据集,以便图像位于名为data/faces/的目录中。该数据集实际上是通过对imagenet中标记为人脸的几幅图像应用出色的dlib姿态估计生成的。

Dataset附带一个csv文件,带有如下注释:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们快速读取CSV并获得(N, 2)数组中的注释,其中N是地标的数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

Out:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

让我们编写一个简单的helper函数来显示图像及其地标,并使用它来显示示例。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()

Dataset class

torch.utils.data.Dataset是表示数据集的抽象类。您的自定义数据集应该继承Dataset并覆盖以下方法:

  • __len__ 使len(dataset)返回数据集的大小。
  • __getitem__ 支持索引,这样dataset[i]就可以用来获取第i个样本

让我们为face landmarks dataset创建一个dataset类。我们将在_init__中读取csv,但将图像的读取留给_getitem__。这是一种内存效率,因为所有的图像不是一次存储在内存中,而是根据需要读取。

我们的数据集的样本将是一个dict {'image': image, 'landmarks': landmarks}。我们的数据集将接受一个可选的参数transform,以便任何所需的处理都可以应用于示例。我们将在下一节中看到transform的有用性。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们实例化这个类并遍历数据示例。我们将打印前4个样品的尺寸并显示它们的地标。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

Out:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transforms

从上面我们可以看到的一个问题是样品的尺寸不一样。大多数神经网络期望图像的大小是固定的。因此,我们需要编写一些预处理代码。让我们创建三个转换:

  • Rescale:缩放图像
  • RandomCrop:从图像中随机裁剪。这是数据扩充。
  • ToTensor:将numpy图像转换为torch图像(需要交换轴)。

我们将把它们写成可调用的类,而不是简单的函数,这样就不必每次调用转换时都传递它的参数。为此,我们只需要实现_call__方法,如果需要的话,还需要实现_init__方法。然后我们可以使用这样的变换:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

下面观察这些转换如何同时应用于图像和地标。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

Compose transforms

现在,我们将转换应用于一个示例。

假设我们想要将图像的短边重新缩放到256,然后从中随机裁剪一个224的正方形。 即,我们想要组合Rescale和RandomCrop变换。 torchvision.transforms.Compose是一个简单的可调用类,它允许我们这样做。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

Iterating through the dataset

让我们把这些放在一起创建一个具有组合转换的数据集。综上所述,每次采样该数据集时:

  • 从文件中动态读取图像
  • 转换应用于读取的图像
  • 由于其中一种转换是随机的,因此数据在抽样时得到了扩充

我们可以用for i in range来遍历创建的数据集,就像以前一样。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

Out:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,我们通过使用一个简单的循环来迭代数据丢失很多特性。特别是,我们错过了:

  • 批处理数据
  • 移动数据
  • 使用多处理工作者并行加载数据。

torch.utils.data.DataLoader是一个迭代器,它提供所有这些特性。下面使用的参数应该是清楚的。其中一个参数是collate_fn。您可以指定如何使用collate_fn来处理样本。然而,默认的collate应该适用于大多数用例。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

Out:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

Afterword: torchvision

在本教程中,我们已经了解了如何编写和使用数据集、转换和dataloader。torchvision包提供了一些常见的数据集和转换。您甚至可能不需要编写自定义类。在torchvision中可用的一个更通用的数据集是ImageFolder。它假定图像按照以下方式组织

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中'蚂蚁','蜜蜂'等是类标签。 类似的通用变换也可以在PIL.Image上运行,如RandomHorizontalFlip,Scale。 您可以使用它们来编写这样的数据加载器:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

猜你喜欢

转载自blog.csdn.net/u013049912/article/details/88421764