PyTorch study notes: how to use Torchvision datasets

According to the official Torchvision documentation, torchvision.datasets is the collection of standard datasets that Torchvision provides.

Let's take CIFAR10 as an example. The dataset contains 60,000 RGB images of 32×32 pixels in 10 categories, with 6,000 images per category; 50,000 of them are training images and 10,000 are test images. Its constructor takes the following parameters:

  • root: The path where the dataset is stored.
  • train: If it is True, the created dataset will be the training set, otherwise the created dataset will be the testing set.
  • transform: A transform, such as those in torchvision.transforms, that is applied to each image in the dataset.
  • target_transform: A transform that is applied to the target (the label).
  • download: If it is True, the data set will be downloaded from the Internet automatically, otherwise it will not be downloaded.

For example:

import torchvision

train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, download=True)

print(train_set[0])  # (<PIL.Image.Image image mode=RGB size=32x32 at 0x24011FC4F40>, 6)
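To make it clearer why indexing the dataset returns an (image, label) pair, and where transform and target_transform come in, here is a minimal sketch in plain Python. TinyDataset and its data are hypothetical stand-ins written for illustration; this is not torchvision's own code, only the general pattern such a dataset follows.

```python
# Hypothetical stand-in for a torchvision-style dataset; not the library's code.
class TinyDataset:
    def __init__(self, images, targets, transform=None, target_transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img, target = self.images[index], self.targets[index]
        # Each optional transform is applied only when one was supplied.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


# Made-up data: strings stand in for images, integers for class indices.
plain = TinyDataset(["img0", "img1"], [6, 9])
changed = TinyDataset(["img0", "img1"], [6, 9],
                      transform=str.upper,
                      target_transform=lambda t: t + 100)

print(plain[0])    # ('img0', 6)
print(changed[0])  # ('IMG0', 106)
```

With no transforms the sample comes back unchanged, which matches the (PIL image, 6) pair printed above.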

The first time you run the code, you can see the dataset being downloaded from the Internet. If the download is very slow, you can copy the download link and fetch the file with a download manager such as Thunder (Xunlei). After downloading, create the directory given by root yourself and put the dataset file there.
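If you place the file by hand, a quick check of the root directory can confirm it landed where expected. The sketch below uses only pathlib; the helper cifar10_status is hypothetical, and the two file names are the ones I believe torchvision uses for CIFAR-10, so treat them as assumptions and compare them with what you actually downloaded.

```python
from pathlib import Path

# Assumed names of the CIFAR-10 archive and of the folder it extracts to.
ARCHIVE_NAME = "cifar-10-python.tar.gz"
EXTRACTED_DIR = "cifar-10-batches-py"


def cifar10_status(root):
    """Report what is already present under root (hypothetical helper)."""
    root = Path(root)
    if (root / EXTRACTED_DIR).is_dir():
        return "extracted"
    if (root / ARCHIVE_NAME).is_file():
        return "archive only"
    return "missing"


print(cifar10_status("dataset/CIFAR10"))
```

If the result is "missing", the file is probably in a different directory than the root passed to CIFAR10.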

Then set a breakpoint and run the code in debug mode to inspect the contents of the dataset object.

You can see that classes lists the image category names, class_to_idx maps each category name to an integer, and targets holds the category index of each image. Try printing the information of the first image:

img, target = train_set[0]
print(img)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x1EEAEC32190>
print(target)  # 6
print(train_set.classes[target])  # frog
img.show()  # the displayed image is a frog
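The relationship between classes, class_to_idx, and targets can be sketched without loading the dataset. The class names below are the ten CIFAR-10 categories in index order; the targets list is a made-up, shortened example, not the real label sequence.

```python
from collections import Counter

# The ten CIFAR-10 category names, in index order.
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# class_to_idx is the inverse lookup: category name -> integer index.
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Hypothetical, shortened label list; the real training set has 50,000 entries.
targets = [6, 9, 9, 4, 1, 6]

print(class_to_idx["frog"])  # 6
print(classes[targets[0]])   # frog

# Count how many images fall into each category.
counts = Counter(classes[t] for t in targets)
print(counts["truck"])       # 2
```

The same Counter pattern applied to train_set.targets should report 5,000 images per category, given the 50,000 training images spread evenly over 10 categories.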

Now let's see how to use the transform parameter. Suppose we need to convert the images in the dataset to tensors:

trans_dataset = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, transform=trans_dataset, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, transform=trans_dataset, download=True)

img, target = train_set[0]
print(type(img))  # <class 'torch.Tensor'>
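Besides changing the type, ToTensor also rearranges the data: a height × width × channel image with values from 0 to 255 becomes a channel × height × width float tensor with values from 0.0 to 1.0. The sketch below mimics that layout change with nested lists only; to_chw_float is a hypothetical helper for illustration and not how torchvision implements it.

```python
def to_chw_float(hwc_pixels):
    """Convert nested lists [H][W][C] of 0-255 ints to [C][H][W] floats in [0, 1]."""
    height = len(hwc_pixels)
    width = len(hwc_pixels[0])
    channels = len(hwc_pixels[0][0])
    return [
        [[hwc_pixels[y][x][c] / 255 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A made-up 1x2 RGB image: one red pixel and one white pixel.
image = [[[255, 0, 0], [255, 255, 255]]]
tensor_like = to_chw_float(image)

# Dimensions are now (channels, height, width).
print(len(tensor_like), len(tensor_like[0]), len(tensor_like[0][0]))  # 3 1 2
# The green channel: 0.0 for the red pixel, 1.0 for the white one.
print(tensor_like[1])  # [[0.0, 1.0]]
```

For a CIFAR10 image, this corresponds to a tensor of shape (3, 32, 32).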

Origin blog.csdn.net/m0_51755720/article/details/128060988