基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader

目录

1. 需要用到的库

2. 数据扩充定义

3. 自定义Dataset

4. 测试


         开始一个新的系列,基于Kaggle比赛的猫狗大战数据集,基于PyTorch实现猫狗图像分类。

         数据集地址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview

         下面是第一部分,主要介绍如何使用Pytorch自定义Dataloader。

1. 需要用到的库

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

2. 数据扩充定义

image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

数据扩充主要分为以下几步:

1)将图像的短边resize到256;

2)然后随即裁减224x224;

3)再进行随机水平翻转;

4)最后将图像转为Tensor并且标准化。

3. 自定义Dataset

class DogVsCatDataset(Dataset):
    """Dog vs Cat dataset."""

    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
        else:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label

数据集初始化时要设置图片目录;是否是训练集或者是验证集,图片编号小于10000的为训练集,大于等于10000的为验证集;及数据扩充方式;猫的标签为0,狗的标签为1。

4. 测试

if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
                                        transform=image_transform)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
    image, label = iter(train_loader).next()
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))

测试的时候使用“if __name__ == '__main__':”可以在其他文件import时,不执行这些语句。执行代码后,显示的图片和打印的标签如下所示:

Label is: [0]

Label is: [1]

发布了20 篇原创文章 · 获赞 6 · 访问量 2173

猜你喜欢

转载自blog.csdn.net/linghu8812/article/details/100044971
今日推荐