Débutants du deep learning, comment télécharger des jeux de données publics communs et les utiliser ?

Cet article a participé à l'événement "Newcomer Creation Ceremony" pour commencer ensemble la route de la création d'or.

1. Introduction

Lorsque j'ai commencé à apprendre en profondeur, il était inévitable d'utiliser certains ensembles de données publics. Maintenant, je n'ai rien à faire et j'enregistre comment télécharger rapidement certains ensembles de données classiques. Apprendre à travers des documents officiels est une méthode souvent recommandée par certaines grosses vaches, nous allons donc commencer à apprendre à partir de documents officiels dans ce blog.

Parce que je fais de la direction de CV, j'utilise la bibliothèque TorchVision comme exemple. Depuis le site officiel :This library is part of the [PyTorch](http://pytorch.org/) project. PyTorch is an open source machine learning framework.

The [torchvision] package consists of popular datasets, model architectures, and common image transformations for computer vision.

Y compris de nombreux ensembles de données populaires, tels que nos CIFAR, COCO et MINST communs, tout le monde devrait être familier. image-20211112215757902Je prendrai l'ICRA comme exemple dans un moment pour consigner mon processus.

2. Comment lire les documents officiels

  1. Voyons d'abord CIFARla documentation de cette classe :

    image-20211112220405128

    paramètre:

    root : indique dans quel répertoire placer le jeu de données téléchargé

    root (string): Root directory of dataset where directory ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
    复制代码

    train : s'il s'agit d'un ensemble de données d'entraînement

    train (bool, optional): If True, creates dataset from training set, otherwise creates from test set.
    复制代码

    transform : une fonction qui prétraite l'image et renvoie la transformation

    A function/transform that takes in an PIL image and returns a transformed version.
    复制代码

    download : s'il faut télécharger le jeu de données,

    download (bool, optional):If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
    复制代码

3. Code pratique

  1. exemple de code

    # 导入torchvision包
    import torchvision
    
    # 对原始图像进行数据处理的函数
    dataset_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    
    # 生成训练数据集和测试数据集
    # 训练数据集 存放在根目录的dataset文件夹下,作为训练数据集,并下载
    train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
    # 测试数据集 存放在根目录的dataset文件夹下,不作为训练数据集,并下载
    test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
    
    print(test_set[0])
    复制代码
  2. Ensuite, nous faisons un clic droit pour exécuter et télécharger

    image-20211113095342812

    On peut voir que l'ensemble de données a été téléchargé, mais parce qu'il est téléchargé à partir de toronto.edu, la vitesse est très lente. Vous apprendre une méthode plus rapide : on termine l'opération, on copie ce lien, on le télécharge avec Thunder, et tout ira bien bientôt. Décompressez ensuite le fichier téléchargé .gzet placez-le dans le datasetrépertoire que nous avons créé :

    image-20211113100111844
  3. Réexécutez, vous pouvez utiliser l'ensemble de données normalement.

    image-20211113100435809

4. Comment visualiser

Je l'ai utilisé tensorboardpour la visualisation. Si vous êtes intéressé, vous pouvez étudier la bibliothèque tensorboard.

import torchvision
from torch.utils.tensorboard import SummaryWriter
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 返回类型
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()
复制代码

Vous pouvez voir l'image dans votre navigateur :

image-20211113100721223

Erreur : ssl.SSLCertVerificationError : [SSL : CERTIFICATE_VERIFY_FAILED] Échec de la vérification du certificat : le certificat a expiré (_ssl.c:1131)

Si vous rencontrez le même problème lors du téléchargement, vous devez importer ssl :

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
复制代码

Le dernier mot : écrire n'est pas facile, si ça vous plait ou vous aide, pensez à liker + suivre ou favori ~

Je suppose que tu aimes

Origine juejin.im/post/7086664505731579917
conseillé
Classement