Cet article a participé à l'événement "Newcomer Creation Ceremony" pour commencer ensemble la route de la création d'or.
1. Introduction
Lorsque j'ai commencé à apprendre en profondeur, il était inévitable d'utiliser certains ensembles de données publics. Maintenant, je n'ai rien à faire et j'enregistre comment télécharger rapidement certains ensembles de données classiques. Apprendre à travers des documents officiels est une méthode souvent recommandée par certaines grosses vaches, nous allons donc commencer à apprendre à partir de documents officiels dans ce blog.
Parce que je fais de la direction de CV, j'utilise la bibliothèque TorchVision comme exemple. Depuis le site officiel :This library is part of the [PyTorch](http://pytorch.org/) project. PyTorch is an open source machine learning framework.
The [torchvision] package consists of popular datasets, model architectures, and common image transformations for computer vision.
Y compris de nombreux ensembles de données populaires, tels que nos CIFAR, COCO et MINST communs, tout le monde devrait être familier. Je prendrai l'ICRA comme exemple dans un moment pour consigner mon processus.
2. Comment lire les documents officiels
-
Voyons d'abord
CIFAR
la documentation de cette classe :paramètre:
root : indique dans quel répertoire placer le jeu de données téléchargé
root (string): Root directory of dataset where directory ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 复制代码
train : s'il s'agit d'un ensemble de données d'entraînement
train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. 复制代码
transform : une fonction qui prétraite l'image et renvoie la transformation
A function/transform that takes in an PIL image and returns a transformed version. 复制代码
download : s'il faut télécharger le jeu de données,
download (bool, optional):If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. 复制代码
3. Code pratique
-
exemple de code
# 导入torchvision包 import torchvision # 对原始图像进行数据处理的函数 dataset_transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor() ]) # 生成训练数据集和测试数据集 # 训练数据集 存放在根目录的dataset文件夹下,作为训练数据集,并下载 train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True) # 测试数据集 存放在根目录的dataset文件夹下,不作为训练数据集,并下载 test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True) print(test_set[0]) 复制代码
-
Ensuite, nous faisons un clic droit pour exécuter et télécharger
On peut voir que l'ensemble de données a été téléchargé, mais parce qu'il est téléchargé à partir de toronto.edu, la vitesse est très lente. Vous apprendre une méthode plus rapide : on termine l'opération, on copie ce lien, on le télécharge avec Thunder, et tout ira bien bientôt. Décompressez ensuite le fichier téléchargé
.gz
et placez-le dans ledataset
répertoire que nous avons créé : -
Réexécutez, vous pouvez utiliser l'ensemble de données normalement.
4. Comment visualiser
Je l'ai utilisé tensorboard
pour la visualisation. Si vous êtes intéressé, vous pouvez étudier la bibliothèque tensorboard.
import torchvision
from torch.utils.tensorboard import SummaryWriter
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 返回类型
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
复制代码
Vous pouvez voir l'image dans votre navigateur :
Erreur : ssl.SSLCertVerificationError : [SSL : CERTIFICATE_VERIFY_FAILED] Échec de la vérification du certificat : le certificat a expiré (_ssl.c:1131)
Si vous rencontrez le même problème lors du téléchargement, vous devez importer ssl :
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
复制代码
Le dernier mot : écrire n'est pas facile, si ça vous plait ou vous aide, pensez à liker + suivre ou favori ~