PyTorch personal learning record summary 04

Table of contents

torchvision

DataLoader


torchvision

Transforms process a single image, but building a dataset requires processing images in bulk. This section therefore uses datasets and transforms from torchvision together to preprocess a dataset.

  1. Official torchvision documentation: torchvision — Torchvision 0.15 documentation
  2. torchvision.datasets provides built-in datasets and the base classes needed for custom datasets (DatasetFolder, ImageFolder, VisionDataset). (Official documentation for torchvision.datasets: Datasets — Torchvision 0.15 documentation)
  3. torchvision.models contains pre-trained neural network models for image classification, image segmentation, and object detection. (Official documentation for torchvision.models: Models and pre-trained weights — Torchvision 0.15 documentation)
  4. torchvision.transforms transforms and augments images. (Official documentation for torchvision.transforms: Transforming and augmenting images — Torchvision 0.15 documentation)
  5. torchvision.utils contains various utility tools, mainly used for visualization (TensorBoard support is in torch.utils.tensorboard). (Official documentation for torchvision.utils: Utils — Torchvision 0.15 documentation)

    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms
    
    # 1. Define the image transform pipeline with transforms
    data_transform = transforms.Compose([  # Compose chains all transform operations together
        transforms.ToTensor()  # CIFAR10 images are small (32x32), so ToTensor is the only step used
    ])
    
    # 2. Load the built-in CIFAR10 dataset and attach the transform
    #   1. root: root directory where the dataset is stored (and downloaded to, if needed)
    #   2. train=True builds the training set (train_set); train=False builds the test set (test_set)
    #   3. transform=data_transform applies data_transform to every image in the dataset
    #   4. download=True downloads the dataset into root; if it is already there, it is not
    #      downloaded again, so it is safe to always leave download=True
    train_set = torchvision.datasets.CIFAR10('./dataset', train=True, transform=data_transform, download=True)
    test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=data_transform, download=True)
    
    # 3. Write the first 10 test images to TensorBoard for inspection
    writer = SummaryWriter('CIFAR10')
    for i in range(10):
        img, label = test_set[i]    # returns the image (a torch.Tensor, because of ToTensor) and its class index (int)
        writer.add_image('test_set', img, i)
    
    writer.close()
    

DataLoader

Official document address: torch.utils.data.DataLoader 

CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 
	sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, 
	pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, 
	multiprocessing_context=None, generator=None, *, prefetch_factor=2, 
	persistent_workers=False)

All parameters except dataset (the dataset to load data from) have default values.
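
As a quick check of those defaults, the sketch below builds a DataLoader from nothing but a dataset. The toy TensorDataset (5 samples, 3 features each) is a hypothetical stand-in so that nothing needs to be downloaded.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 5 samples with 3 features each, labels 0..4.
features = torch.arange(15, dtype=torch.float32).reshape(5, 3)
labels = torch.arange(5)
toy_set = TensorDataset(features, labels)

loader = DataLoader(toy_set)  # only dataset is given; everything else uses its default

print(loader.batch_size)   # 1
print(loader.num_workers)  # 0
print(len(loader))         # 5 batches, one sample each

first_x, first_y = next(iter(loader))
print(first_x.shape)       # torch.Size([1, 3])
print(first_y)             # tensor([0]), since shuffle defaults to False
```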

The parameters of torch.utils.data.DataLoader to focus on are:

  • dataset (Dataset) : The dataset to load data from (for example, the train_set defined in the previous section).
  • batch_size (int) : How many samples to load per batch.
  • shuffle (bool) : Whether to reshuffle the samples at every epoch. (Usually set to True for training.)
  • num_workers (int) : How many subprocesses to use for data loading. 0 means the data is loaded in the main process.
  • (On Windows, worker processes are started by spawning, so num_workers > 0 typically raises an error unless the script's main code is wrapped in if __name__ == '__main__':. Setting num_workers=0 explicitly, even though it is the default, avoids the problem.)
  • drop_last (bool) : If the dataset size is not divisible by batch_size, the last batch is incomplete (i.e. number of samples < batch_size). Set to True to drop that last batch, or False to keep it (the default is False, so the last incomplete batch is kept).
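
The effect of batch_size, shuffle, and drop_last can be sketched on a small synthetic dataset. The 10-sample TensorDataset below is a hypothetical stand-in, chosen so that batch_size=4 does not divide it evenly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 samples, so batch_size=4 leaves a remainder of 2.
toy_set = TensorDataset(torch.arange(10).reshape(10, 1), torch.arange(10))

keep_loader = DataLoader(toy_set, batch_size=4, shuffle=False, num_workers=0, drop_last=False)
drop_loader = DataLoader(toy_set, batch_size=4, shuffle=False, num_workers=0, drop_last=True)

keep_sizes = [len(y) for _, y in keep_loader]
drop_sizes = [len(y) for _, y in drop_loader]
print(keep_sizes)  # [4, 4, 2]: the incomplete last batch is kept
print(drop_sizes)  # [4, 4]: the incomplete last batch is dropped

# shuffle=True changes the order, but each sample is still visited once per epoch.
shuffled_loader = DataLoader(toy_set, batch_size=4, shuffle=True, num_workers=0)
seen = sorted(int(v) for _, y in shuffled_loader for v in y)
print(seen == list(range(10)))  # True
```

Replacing toy_set with the train_set or test_set from the torchvision section should work the same way, with each batch then holding image tensors and their class indices.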

 


Origin blog.csdn.net/timberman666/article/details/131873616