Official Torchvision documentation: Torchvision
torchvision.datasets contains the standard datasets provided by Torchvision.
Let's take CIFAR10 as an example. The dataset contains 60,000 color images of 32x32 pixels in 10 categories, with 6,000 images per category. Of these, 50,000 are training images and 10,000 are test images. The parameters of the CIFAR10 class are as follows:
root: The path where the dataset is stored.
train: If True, the training set is created; otherwise the test set is created.
transform: A transform applied to each image, typically built from the operations in torchvision.transforms.
target_transform: A transform applied to each target (label).
download: If True, the dataset is downloaded from the Internet automatically; otherwise it is not downloaded. If the files are already present in root, they are not downloaded again.
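To make the roles of transform and target_transform concrete, here is a minimal sketch of a dataset that applies both when a sample is indexed. ToyDataset and its sample data are hypothetical stand-ins written for illustration; this is not the torchvision CIFAR10 class, only the same general pattern in plain Python.

```python
# Hypothetical stand-in for a torchvision-style dataset (not the real CIFAR10 class).
class ToyDataset:
    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples                    # list of (image, target) pairs
        self.transform = transform                # applied to the image
        self.target_transform = target_transform  # applied to the target

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img, target = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


# Strings stand in for images so the sketch needs no image library.
samples = [("image_0", 6), ("image_1", 9)]
dataset = ToyDataset(samples, transform=str.upper, target_transform=lambda t: t + 100)

first_img, first_target = dataset[0]
print(first_img, first_target)
```

Both callables run only when a sample is accessed, so the stored data itself is left unchanged.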
For example, create the training and test sets and print the first training sample:
import torchvision
train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, download=True)
print(train_set[0]) # (<PIL.Image.Image image mode=RGB size=32x32 at 0x24011FC4F40>, 6)
When you first run the code, you will see the dataset being downloaded from the Internet. If the download is very slow, you can copy the link shown in the output and fetch the file with a download manager such as Thunder (Xunlei). After downloading, create the root directory yourself and place the dataset file in it.
Then set a breakpoint and run the code in Debug mode to inspect the contents of the dataset.
You can see classes, which lists the image categories; class_to_idx, which maps each category name to an integer; and targets, which holds the category index of each image. Next, output the information for the first image:
img, target = train_set[0]
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x1EEAEC32190>
print(target) # 6
print(train_set.classes[target]) # frog
img.show() # the displayed image is a frog
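The relationship between classes, class_to_idx, and targets can be sketched in plain Python without loading the dataset. The class names below are the ten CIFAR10 categories in index order; the targets list is illustrative, and only its first entry (6, the frog) comes from the output above.

```python
# The ten CIFAR10 category names, in index order.
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Map each category name to its integer index, mirroring the idea of class_to_idx.
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Illustrative targets: only the first value (6) is taken from the text.
targets = [6, 9, 4]

# Look up the category name for each target index.
labels = [classes[t] for t in targets]
print(class_to_idx["frog"])
print(labels)
```

Indexing classes with a target is the same lookup as train_set.classes[target] in the code above.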
Next, let's see how to use the transform parameter. Suppose we need to convert the images in the dataset to tensors:
# Compose chains transforms in order; here it contains only ToTensor
trans_dataset = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, transform=trans_dataset, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, transform=trans_dataset, download=True)
img, target = train_set[0]
print(type(img)) # <class 'torch.Tensor'>
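For an 8-bit RGB PIL image, ToTensor reorders the data from height x width x channel to channel x height x width and scales the values from 0-255 to the range 0.0-1.0, so a CIFAR10 image becomes a tensor of shape (3, 32, 32). The following is a rough sketch of that conversion using nested lists; to_chw_float and the 2x2 image are hypothetical and are not part of torchvision.

```python
def to_chw_float(pixels_hwc):
    """Convert an H x W x C nested list of 0-255 ints to C x H x W floats in [0, 1]."""
    height = len(pixels_hwc)
    width = len(pixels_hwc[0])
    channels = len(pixels_hwc[0][0])
    return [
        [[pixels_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A hypothetical 2 x 2 RGB image: red, green / blue, white.
image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = to_chw_float(image)
shape = (len(chw), len(chw[0]), len(chw[0][0]))
print(shape)   # channels come first after the conversion
print(chw[0])  # the red channel, scaled to 0.0-1.0
```

The real transform returns a torch.Tensor instead of nested lists, but the layout and value range follow the same idea.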