data loading and processing

解决任何机器学习问题的努力,很多都花费在准备数据上。PyTorch提供了很多工具来让数据载入变得简单,并期望能够使你的代码变得更可读。在本教程中,我们将会了解如何载入和预处理/扩充非琐碎(non trivial)数据集中的数据。

为了运行本教程,请确保以下包已下载:

  • scikit-image: 用于图像的io和修改。
  • pandas: 为了更容易地解析csv。(这里遇到一个问题:在import pandas时提示’No Module named pandas’。但是我看了一下我是安装了这个包的。询问他人后才明白,原来我的pandas包是安装在了python环境下,但并没有在pytorch环境下安装。于是我打开命令管理器,用activate pytorch激活环境后再pip install pandas就解决问题了。)
## from __future__一般用在使用py2的代码里,它使py2的编译器能够使用一些py3的特性
## 这里是抄的教程的代码,但是实测不用它也可以,因为我用的编译器是py3
## 更详细的解释见 https://blog.csdn.net/qq_36306781/article/details/83018179
from __future__ import print_function,division

## 这是python的一个包,用于提供操作系统的各个功能,在本教程中,用于连接文件地址
## 更详细的见 https://blog.csdn.net/weixin_39541558/article/details/79971971
import os

import torch

## 一个专门用于处理数据的包,名称来自panel data 和data analysis,本教程中,用于读取csv
## 更详细的见 https://www.cnblogs.com/misswangxing/p/7903595.html
import pandas as pd

## skimage是一个用于处理数据的开源包
## io顾名思义,使用来处理图片的输入和输出的
## 更详细的见 https://blog.csdn.net/weixin_39549734/article/details/81234606
## transform则用于做图像的处理和缩放
from skimage import io,transform

from torch.utils.data import Dataset,DataLoader
import numpy as np
import matplotlib.pyplot as plt

## transfroms为常见的图像变换,这些变换可以使用Compose()来链接到一起
## 更详细的见 https://pytorch.org/docs/stable/torchvision/transforms.html
## utils只包含两个函数,make_grid()用于制作一个图像网格
## save_image()用于把一个Tensor存储为图像
## 更详细的见 https://pytorch.org/docs/stable/torchvision/utils.html
from torchvision import transforms,utils

## warnings模块用来控制警告。它不仅控制是否发出警告,也控制发出警告的格式
## 这里使用了警告过滤器,ignore值的效果是忽略所有警告
## 更详细的见 http://blog.konghy.cn/2017/12/16/python-warnings/
import warnings
warnings.filterwarnings("ignore")

## ion()的效果是打开交互模式,在交互模式下,imshow()和plot()会直接出图像
## 而不需要再调用show(),而且可以同时打开几个图像窗口
## 更详细的见 https://blog.csdn.net/M_Z_G_Y/article/details/80309446
plt.ion()

我们要处理的数据集是面部姿态。这意味着脸要像下图一样被标注:

在这里插入图片描述
总之,每个面部都被标注了68个标注点。

Note:
从官方教程 https://pytorch.org/tutorials/beginner/data_loading_tutorial.html# 上下载数据集。该数据集实际上是靠对来自imageNet上标记为脸的一些图片应用优秀的dlib姿态估计来生成的。

数据集附带了一个带注释的csv文件:

在这里插入图片描述
我们接下来快速地看一下CSV文件,并得到一个由标注构成的(N,2)数组,其中N为标注的数目:

## 读取csv文件
landmarks_frame = pd.read_csv('C:/Users/majx1/Music/faces/face_landmarks.csv')


n = 65
## iloc()读取第n行0列数据
img_name = landmarks_frame.iloc[n,0]

## 读取第n行1到最后的数据,并转换为array类型
landmarks = landmarks_frame.iloc[n,1:].as_matrix()

## astype()转换数据类型,reshape()转换数据格式,将其变为两列
landmarks = landmarks.astype('float').reshape(-1,2)

## {}和format()相当于以前的%,format()接受的参数个数和顺序可以不定
## 更详细的见 https://www.runoob.com/python/att-string-format.html
print('图像名称: {}'.format(img_name))
print('标注大小: {}'.format(landmarks.shape))
print('前四个标注: {}'.format(landmarks[:4]))

其结果为:
在这里插入图片描述

我们可以写一个简单辅助函数来显示图片及其标注,并用它来显示一个样本。

def show_landmarks(image,landmarks):
    ## imshow()的作用是对图像进行处理并显示其格式,但不负责显示,要调用show()进行显示
    ## 当你的图像非正规,或你想显示独特的颜色图谱,你就需要调用它
    ## 更详细的见 https://blog.csdn.net/wwwlyj123321/article/details/89023570
    plt.imshow(image)
    
    ## 绘制散点图
    plt.scatter(landmarks[:,0],landmarks[:,1],s=20,marker='.',c='r')
    
    ## 暂停绘图一段时间,此时仍可以与图像交互
    ## 更详细的见 https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.pause.html
    plt.pause(0.001)

## figure()用来显示图像,并设置它的一些参数
## 更详细的见 https://liam.page/2014/09/11/matplotlib-tutorial-zh-cn/
plt.figure()

## imread读出一个由uint8组成的numpy array,其和cv2。imread()的区别在于,它读出来是RGB,而cv2为BGR
## 更详细的见 https://blog.csdn.net/qq_23589775/article/details/81143584
show_landmarks(io.imread(os.path.join('C:/Users/majx1/Music/faces/',img_name)),landmarks)'

plt.show()

其结果为:
在这里插入图片描述
(注:教程这里使用“.”做标注,实测除了“.“外,还可以使用“,” “^” “*”,不过除了使用”,“时标注点会大一些,其它标注和”."并没有什么区别。
在这里插入图片描述)

数据集类别

torch.utils.data.Dataset是一个表示数据集的抽象类。你的自定义数据集应该继承Dataset类并重写以下方法:

  • len: 使len(dataset)可以返回数据集的大小。
  • getitem:来支持索引,使得dataset[i]可以用来获取第i个样本。

让我们为我们的面部标注数据集建立一个数据集类。我们将会用__init__读取csv,用__getitem__来读取图片。这使得读取变得高效,因为所有的图片不是一次性加载到内存中的,而是在需要的时候才被读取。

我们数据集中的样本将会以dict的形式存储:{‘image’:image,‘landmarks’,landmarks}。我们的数据集将会使用一个可选的参数transform,使得任何需要的处理可以被应用到样本上。我们将会在下一节看到transform的好处。

class FaceLandmarksDataset(Dataset):
    def __init__(self,csv_file,root_dir,transform=None):
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.landmarks_frame)
    
    def __getitem__(self,idx):
        ## 获得图片文件名
        img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
        
        ## 读取文件
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx,1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1,2)
        sample = {'image':image,'landmarks':landmarks}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

接下来初始化该类,并迭代数据样本。我们将会显示前四个样本的大小,并展示它们的标注。

face_dataset = FaceLandmarksDataset(csv_file='C:/Users/majx1/Music/faces/face_landmarks.csv',
                                   root_dir='C:/Users/majx1/Music/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]
    
    print(i,sample['image'].shape,sample['landmarks'].shape)
    
    ## 前两个参数为绘制区域显示图片的行列数,第三个参数为当前图片所在的位置(行优先)
    ax = plt.subplot(1,4,i+1)
    
    ## tight_tight()函数会自动调整子图的参数,使其尽可能紧密地排列
    ## 不过这是一个实验特性,在有些时候可能会不起作用——比如在我的环境下
    plt.tight_layout()
    
    ## set_title()设置子图标题
    ax.set_title('Sample #{}'.format(i))
    
    ## 关闭图片上的坐标轴,但是在网上没找到该函数的解释以及是否有其它参数和取值
    ax.axis('off')
    
    ## **和*有很多含义,放在实参前时,**的意思是表示它传递了一个dict,*则是一个元组
    ## 更详细的见 https://blog.csdn.net/yilovexing/article/details/80577510
    show_landmarks(**sample)
    
    if i == 3:
        plt.show()
        break

其结果为:

在这里插入图片描述

图像变换

从上面我们可以看出,样本的大小并不相同。大多数神经网络期望图片的大小固定。因此,我们要写一些处理代码。我们可以创建三种变换:

  • Rescale:缩放图片。
  • RandomCrop:从图像随机裁剪。这是数据增加。
  • ToTensor:将numpy图像转化为torch图像(需要转换数轴,因为这两种图像的数轴并不相同)。

我们将把它们写成可调用的类,而不是简单的函数,从而变换需要的参数不必在每次调用时都被传递。为了达成这个目标,我们只需要简单的实现__call__方法,以及__init__方法——如果需要的话。接下来我们就可以像下面这样进行变换:

tsfm = Transform(param)
transformed_sample = tsfm(sample)

观察下面的代码是如何将变换同时应用在图像和标注上的。

## 继承了object类后,你的类就能够实现很多高级特性,但在python3中无论写不写都会默认继承。
## 更详细的见 https://blog.csdn.net/DeepOscar/article/details/80947155
class Rescale(object):
    def __init__(self,output_size):
        ## assert为断言函数,当条件不符合时引发AssertionError
        ## 更详细的见 https://blog.csdn.net/hunyxv/article/details/52737339
        
        ## isinstance()检查两个参数是否为同类型
        ## 更详细的见 https://www.runoob.com/python/python-func-isinstance.html
        
        assert isinstance(output_size,(int,tuple))
        self.output_size = output_size
    
    ## 在python中,函数的对象的区别并不大,只要在类中定义__call__(),就可以使用实例化的对象调用该函数。
    ## 更详细的见 https://www.cnblogs.com/superxuezhazha/p/5793536.html
    def __call__(self,sample):
        image,landmarks = sample['image'],sample['landmarks']
        
        ## 获得图片的长宽
        h,w = image.shape[:2]
        if isinstance(self.output_size,int):
        
            ## 以outsize为基准缩放图片
            if h > w:
                new_h,new_w = self.output_size * h / w,self.output_size
            else:
                new_h,new_w = self.output_size,self.output_size * w / h
                
        ## 如果outsize为元组
        else:
            new_h,new_w = self.output_size
        
        new_h,new_w = int(new_h),int(new_w)
        
        ## transform.resize(image,output.shape)通过拉伸将左边的图像的大小变为右边
        ## 更详细的见 https://blog.csdn.net/wuguangbin1230/article/details/71107109
        img = transform.resize(image,(new_h,new_w))
        
        landmarks = landmarks * [new_w/w,new_h/h]
        
        return {'image':img,'landmarks':landmarks}
    
class RandomCrop(object):
    def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        if isinstance(output_size,int):
            self.output_size = (output_size,output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
            
    def __call__(self,sample):
        image,landmarks = sample['image'],sample['landmarks']
        
        h,w = image.shape[:2]
        new_h,new_w = self.output_size
        
        top = np.random.randint(0,h - new_h)
        left = np.random.randint(0,w - new_w)
        
        image = image[top:top + new_h,left:left + new_w]
        
        landmarks = landmarks - [left,top]
        
        return {'image':image,'landmarks':landmarks}
    
class ToTensor(object):
    def __call__(self,sample):
        image,landmarks = sample['image'],sample['landmarks']
            
        image = image.transpose((2,0,1))
        ## from_numpy(ndarray)将ndarray转换为Tensor
        ## 更详细的见 http://www.mamicode.com/info-detail-2217311.html
        return {'image':torch.from_numpy(image),
                    'landmarks':torch.from_numpy(landmarks)}

组合变换

现在,我们将这些变换应用到样本上。我们想要把图像较短的部分放缩到256,然后随机地从上面裁切出224大小的方块,即我们要将Rescale和RandomCrop进行组合。torchvision.transforms.Compose是一个能够做到这件事的简单的可调用类。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),RandomCrop(224)])

fig = plt.figure()
sample = face_dataset[65]
## enumerate将元组、列表和字符串等可遍历的数据对象转换为一个索引序列
## 更详细的见 https://www.runoob.com/python/python-func-enumerate.html
for i,tsfrm in enumerate([scale,crop,composed]):
    transformed_sample = tsfrm(sample)
    
    ax = plt.subplot(1,3,i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)
    
plt.show()

其结果为:
在这里插入图片描述

迭代数据集

让我们把这些放在一起,来创建一个有着组合变换的数据集。总而言之,每次采样该数据集时:

  • 即时地从文件中读取图像。
  • 在读取的图像上应用变换。
  • 由于其中一个变换是随机的,因此在采样时会增加数据。

我们可以像之前一样,使用for i in range来对创建的数据集进行迭代操作。

transformed_dataset = FaceLandmarksDataset(csv_file='C:/Users/majx1/Music/faces/face_landmarks.csv',
                                          root_dir='C:/Users/majx1/Music/faces/',
                                          transform=transforms.Compose([
                                              Rescale(256),RandomCrop(224),
                                              ToTensor()
                                          ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]
    
    print(i,sample['image'].size(),sample['landmarks'].size())
    
    if i == 3:
        break

其结果为:

在这里插入图片描述
然而,当使用简单的for循环来对数据进行迭代操作时,我们将不能够使用很多功能。尤其是以下功能:

  • 批量处理数据
  • 对数据进行“洗牌”
  • 使用multiprocessing程序来并行地加载数据。

torch.utils.data.DataLoader是一个可以提供以上功能的迭代器。下面使用的参数应该是清楚的。一个有趣的参数是collate_fn。你可以通过使用collate_fn来指定批量处理样本的准确程度。然而,默认的校勘(collate,也许是指默认值? )在大多数情况下效果很好。

dataloader = DataLoader(transformed_dataset,batch_size=4,shuffle=True,num_workers=4)

def show_landmarks_batch(sample_batched):
    images_batch,landmarks_batch = sample_batched['image'],sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2
    
    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1,2,0)))
    
    for i in range(batch_size):
        plt.scatter(landmarks_batch[i,:,0].numpy() + i * im_size + 
                    (i + 1) * grid_border_size,landmarks_batch[i,:,1].numpy() + 
                   grid_border_size,s=10,marker='',c='r')
        
        plt.title('Batch form dataloader')
        
    for i_batch,sample_batched in enumerate(dataloader):
        print(i_batch,sample_batched['image'].size(),
             sample_batched['landmarks'].size())
        
        if i_batch == 3:
            plt.figure()
            show_landmarks_batch(sample_batched)
            plt.axis('off')
            plt.ioff()
            plt.show()
            break

其结果为:

在这里插入图片描述

在本教程中,我们已经看到了如何写入和使用datasets、transforms和dataloader。torchvision包提供了一些普通的datasets和transforms。你可能不必写一些自定义的类。torchvision中一个更通用的可用数据集为ImageFolder。它假设图片按照下图的方式排列:

D
其中ants、bees为类别标签。类似的像RandomHorizontalFlip和Scale这样能够在PIL.Image上运行的通用变换也是可用的。你可以使用它们来写一个下面这样的dataloader:

import torch
from torchvision import transforms,datasets

data_transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])

## 这里的hymenopters_data/train等同于上图中的root文件夹
hymenotera_dataset = datasets.ImageFolder(root='hymenopters_data/train',
                                          transform=data_transform)

dataset_loader = torch.utils.data.DataLoader(hymenopters_dataset,
                                             batch_size=4,shuffle=True,
                                            num_workers=4)

发布了74 篇原创文章 · 获赞 14 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/JachinMa/article/details/94430174