Deep learning for beginners: how to download and use common public datasets


1. Introduction

When I first started with deep learning, I inevitably needed some public datasets. Now that I have some free time, I am recording how to quickly download a few classic ones. Experienced practitioners often recommend learning from the official documentation, so this post starts from the official docs.

Because I work on computer vision (CV), I use the TorchVision library as the example. From the official website: This library is part of the [PyTorch](http://pytorch.org/) project. PyTorch is an open source machine learning framework.

The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

It includes many popular datasets that everyone should be familiar with, such as CIFAR, COCO, and MNIST. Below I use CIFAR as an example to record my process.

2. How to read the official documents

  1. First, let's look at the documentation for the CIFAR10 class:


    Parameters:

    root: Indicates which directory to place the downloaded dataset in

    root (string): Root directory of dataset where directory ``cifar-10-batches-py`` exists or will be saved to if download is set to True.

    train: whether to load the training set (True) or the test set (False)

    train (bool, optional): If True, creates dataset from training set, otherwise creates from test set.

    transform: a function that preprocesses each image and returns the transformed version

    A function/transform that takes in an PIL image and returns a transformed version.

    download: whether to download the dataset

    download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
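The root parameter mentions the cifar-10-batches-py folder. As far as I know, that folder holds the Python-pickle version of CIFAR-10: data_batch_1 to data_batch_5 for training and test_batch for testing, each storing a data array with one 3072-value row per 32x32 image plus a labels list. The sketch below reads one such file with the standard pickle module to show what the dataset class wraps. load_cifar_batch and pixel are hypothetical helper names, the channel layout (1024 red values, then green, then blue, each row-major) is taken from the CIFAR-10 homepage description rather than from torchvision, and loading the real files additionally requires numpy because the data is stored as a numpy array.

```python
import pickle


def load_cifar_batch(path):
    """Hypothetical helper: load one CIFAR-10 batch file, e.g. "data_batch_1" or "test_batch".

    encoding="bytes" follows the loader suggested on the CIFAR-10 homepage,
    so the dictionary keys come back as bytes (b"data", b"labels").
    """
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")
    data = batch[b"data"]      # one flat row of 3072 values (0-255) per image
    labels = batch[b"labels"]  # one class index (0-9) per image
    return data, labels


def pixel(flat_row, row, col, size=32):
    """Hypothetical helper: return the (R, G, B) values of pixel (row, col).

    A flat row stores size*size red values, then green, then blue,
    each channel in row-major order.
    """
    plane = size * size
    offset = row * size + col
    return tuple(flat_row[channel * plane + offset] for channel in range(3))
```

In practice torchvision.datasets.CIFAR10 does this unpacking for you; the sketch is only meant to make the on-disk format less of a black box.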

3. Hands-on code

  1. Sample code

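    # Import the torchvision package
    import torchvision
    
    # Preprocessing applied to each raw image: convert the PIL image to a tensor
    dataset_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    
    # Create the training and test datasets
    # Training set: stored in the dataset folder under the current directory, train=True, downloaded if missing
    train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
    # Test set: same folder, train=False selects the test split, downloaded if missing
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
    
    # Print the first test sample: an (image tensor, class index) tuple
    print(test_set[0])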
  2. Then right-click the file and run it to start the download.


    You can see that the download has started, but because the file is served from toronto.edu, it is very slow. A faster approach: stop the script, copy the download link it prints, and download the file with a download manager such as Thunder (Xunlei); it finishes quickly. Then extract the downloaded .tar.gz archive and put the result in the dataset directory we created.

  3. Run the script again; the dataset now loads normally.

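If you take the manual route from step 2, the extraction step can also be scripted. Below is a minimal sketch using only the standard library, assuming the archive has already been downloaded. extract_manual_download and md5_of are hypothetical helper names, and the optional MD5 check only helps if you have the checksum published for the file.

```python
import hashlib
import tarfile
from pathlib import Path


def md5_of(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def extract_manual_download(archive, root="./dataset", expected_md5=None):
    """Hypothetical helper: unpack a manually downloaded .tar.gz into `root`.

    If expected_md5 is given, the archive is checked first, so a truncated
    download fails loudly instead of producing a half-extracted dataset.
    """
    archive = Path(archive)
    if expected_md5 is not None and md5_of(archive) != expected_md5:
        raise ValueError(f"MD5 mismatch for {archive}; the download may be incomplete")
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(root, filter="data")  # safer extraction on newer Pythons
        else:
            tar.extractall(root)
    return root
```

For CIFAR-10 the archive is named cifar-10-python.tar.gz, so extract_manual_download("cifar-10-python.tar.gz", "./dataset") should leave the cifar-10-batches-py folder under ./dataset, which is what the root parameter expects; rerunning the script with download=True should then find the files instead of downloading them again.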

4. How to visualize

I used TensorBoard for visualization; if you are interested, the tensorboard library is worth studying on its own.

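import torchvision
from torch.utils.tensorboard import SummaryWriter
import ssl

# Work around the expired-certificate error described below
# (this disables HTTPS certificate verification for the whole script)
ssl._create_default_https_context = ssl._create_unverified_context

# Convert each PIL image to a tensor
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# Load the datasets; each item is an (image, target) tuple
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])

# Write the first 10 test images to the log directory "p10"
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)  # tag, image tensor, global step

writer.close()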

Start TensorBoard with tensorboard --logdir=p10 and open the address it prints in your browser to see the images.

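Two details in the TensorBoard code are easy to miss: ToTensor turns each image into a channel-first (C, H, W) float tensor scaled to [0, 1], and SummaryWriter.add_image expects that CHW layout by default (its dataformats argument defaults to "CHW"; pass dataformats="HWC" for channel-last arrays). The pure-Python sketch below only illustrates the layout change ToTensor performs on an 8-bit RGB image; hwc_to_chw_float is a hypothetical name and no torch is involved.

```python
def hwc_to_chw_float(image):
    """Hypothetical pure-Python stand-in for what ToTensor does to an 8-bit RGB image.

    `image` is nested lists indexed [row][col][channel] with values 0-255 (HWC).
    Returns nested lists indexed [channel][row][col] with values in [0.0, 1.0] (CHW).
    """
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[r][c][ch] / 255.0 for c in range(width)] for r in range(height)]
        for ch in range(channels)
    ]
```

For a 1x2 image [[[255, 0, 51], [0, 255, 0]]], the result has 3 channel planes of shape 1x2, with the red plane equal to [[1.0, 0.0]].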

Problem encountered: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1131)

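If you hit the same error during the download, add these lines at the top of the script. Note that they turn off HTTPS certificate verification for every connection the script makes, so treat them as a temporary workaround: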

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
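Because that workaround stays in effect for the whole process, a slightly more contained variant is to switch verification off only while the dataset is downloading and restore it afterwards. This is a minimal sketch, not the author's method; unverified_https is a hypothetical name, and it assumes the download goes through Python's urllib/http.client, which build their default HTTPS context via ssl._create_default_https_context (the same private hook the workaround patches).

```python
import contextlib
import ssl


@contextlib.contextmanager
def unverified_https():
    """Hypothetical helper: disable HTTPS certificate verification temporarily.

    Swaps ssl._create_default_https_context for the unverified factory and
    restores the original on exit, even if the body raises.
    """
    original = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        ssl._create_default_https_context = original
```

Usage would be to wrap only the dataset construction, e.g. with unverified_https(): followed by the torchvision.datasets.CIFAR10(..., download=True) calls, so later HTTPS requests in the script are verified again.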



Source: juejin.im/post/7086664505731579917