torchvision
torchvision.transforms processes one image at a time, but building a dataset means every image in it has to be processed. This section therefore preprocesses a dataset by using torchvision.datasets together with torchvision.transforms.
- Official torchvision documentation: torchvision — Torchvision 0.15 documentation
- torchvision.datasets provides the built-in datasets as well as the classes needed to build custom datasets (DatasetFolder, ImageFolder, VisionDataset). (Official documentation for torchvision.datasets: Datasets — Torchvision 0.15 documentation)
torchvision.models
Contains pre-trained neural network models for image classification, image segmentation, and object detection. (Official documentation for torchvision.models: Models and pre-trained weights — Torchvision 0.15 documentation)
torchvision.transforms
Transforms and augments images. (Official documentation for torchvision.transforms: Transforming and augmenting images — Torchvision 0.15 documentation)
torchvision.utils
Contains various utilities, mainly for visualization (TensorBoard logging itself lives in torch.utils.tensorboard, not in torchvision.utils). (Official documentation for torchvision.utils: Utils — Torchvision 0.15 documentation)

The following script loads the built-in CIFAR10 dataset with a transform and writes the first 10 test images to TensorBoard:

```python
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# 1. Use transforms to define how each image is converted
data_transform = transforms.Compose([  # Compose chains all transform operations together
    transforms.ToTensor()  # CIFAR10 images are small (32x32), so only ToTensor is applied
])

# 2. Load the built-in CIFAR10 dataset and attach the transform
#    root: directory where the dataset is (or will be) stored
#    train=True / False: build the training set (train_set) or the test set (test_set)
#    transform=data_transform: preprocess every image with data_transform
#    download=True / False: whether to download the dataset into root; if it already
#    exists under root it is not downloaded again, so keeping download=True is safe
train_set = torchvision.datasets.CIFAR10('./dataset', train=True, transform=data_transform, download=True)
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=data_transform, download=True)

# 3. Write the images to TensorBoard for inspection
writer = SummaryWriter('CIFAR10')
for i in range(10):
    # test_set[i] returns (image, label): a 3x32x32 float tensor (because of ToTensor) and an int class index
    img, label = test_set[i]
    writer.add_image('test_set', img, i)
writer.close()
```
DataLoader
Official documentation: torch.utils.data.DataLoader
CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
sampler=None, batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
multiprocessing_context=None, generator=None, *, prefetch_factor=2,
persistent_workers=False)
All parameters except dataset (which specifies the dataset to load from) have default values.
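Because every argument except dataset has a default, DataLoader(dataset) on its own already works: with the defaults it yields batches of a single sample in the original order. A minimal sketch on a hypothetical TensorDataset of 5 toy samples:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 5 samples with 3 features each, labels 0..4
features = torch.arange(15, dtype=torch.float32).reshape(5, 3)
labels = torch.arange(5)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset)  # every other argument keeps its default (batch_size=1, no shuffling)
x, y = next(iter(loader))
print(x.shape, y)   # torch.Size([1, 3]) tensor([0])
print(len(loader))  # 5 batches, one per sample
```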
The parameters of torch.utils.data.DataLoader worth focusing on are:
- dataset (Dataset): the dataset to load data from (for example, the train_set defined in the previous section).
- batch_size (int): how many samples to load per batch.
- shuffle (bool): whether to reshuffle the order of the samples at every epoch (preferably set to True, at least for the training set).
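- num_workers (int): how many subprocesses to use for data loading. 0 means the data is loaded in the main process. (On Windows, a value greater than 0 raises an error unless the script's entry code is wrapped in an if __name__ == '__main__': guard, because each worker process re-imports the script. The default is already 0, but setting num_workers=0 explicitly is the safest choice there.)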
- drop_last (bool): if the dataset size is not divisible by batch_size, the last batch is incomplete (it has fewer than batch_size samples). True drops that last batch; False keeps it (the default is False, i.e. the incomplete last batch is kept).