The use of common transforms in PyTorch

This experiment practices reading data in PyTorch, using the Dataset class, and using the transforms module.

One. Introduction to PyTorch

PyTorch is an open-source Python machine learning library based on Torch, used for applications such as computer vision and natural language processing.

In January 2017, Facebook AI Research (FAIR) released PyTorch, building on Torch. It is a Python-based scientific computing package that provides two high-level features: 1. tensor computation (like NumPy) with strong GPU acceleration; 2. deep neural networks built on an automatic differentiation (autograd) system.
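To make the two features concrete, here is a small sketch; the tensors and values are arbitrary examples, not taken from the original post. It places the first computation on the GPU when CUDA is available and on the CPU otherwise.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature 1: NumPy-like tensor computation on the chosen device
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)
b = torch.ones(2, 2, device=device)
c = a @ b + 1  # matrix product, then a scalar broadcast to every element

# Feature 2: automatic differentiation (autograd)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = x1^2 + x2^2 + x3^2
y.backward()        # fills x.grad with dy/dx = 2 * x

print(c.cpu())
print(x.grad)       # tensor([2., 4., 6.])
```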

Two. PyTorch environment configuration

There are plenty of tutorials online about setting up a PyTorch environment, so I won't go into details here.

Three. Basic use of the Dataset class

Dataset class: processes the data and provides a way to retrieve each sample together with its corresponding label.

DataLoader class: takes the samples provided by a Dataset, packs them into batches, and supplies them in different forms to the downstream network.
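The division of labour between the two classes can be sketched with a toy dataset. SquaresDataset is a hypothetical example made up for illustration: the Dataset answers "give me sample idx", while the DataLoader asks for several indices and stacks the results into batch tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i], its label is i * i."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Dataset's job: return ONE sample and its label for a given index
        return torch.tensor([float(idx)]), idx * idx

    def __len__(self):
        # Number of samples; DataLoader uses it to generate indices
        return self.n


dataset = SquaresDataset(6)
print(dataset[2])  # (tensor([2.]), 4)

# DataLoader's job: fetch several samples and pack them into batches
loader = DataLoader(dataset, batch_size=3, shuffle=False)
for features, labels in loader:
    # features has shape (3, 1); the int labels are collated into one tensor
    print(features.shape, labels)
```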

1. First import the Dataset class

import os
from PIL import Image
from torch.utils.data import Dataset

2. Create a class that inherits from the Dataset class

class MyData(Dataset):

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # os.path.join joins the two paths together
        # If the root path is dataset\train and the label path is ants, the joined path is dataset\train\ants
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir(path) takes any path and returns a list of all files and directories under it
        self.img_path = os.listdir(self.path)

    # Get one image and its label by index
    def __getitem__(self, idx):
        # idx is the index of the image; img_name is the file name at that index
        img_name = self.img_path[idx]
        # Join the full path of the image
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # Open the image
        img = Image.open(img_item_path)
        # The label is the name of the folder the image is in (e.g. ants)
        label = self.label_dir
        # Return the image and its label
        return img, label

    def __len__(self):
        # Return the number of images
        return len(self.img_path)
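A quick way to try MyData without real data is to build a small throwaway folder tree. The sketch below assumes the dataset\train\ants layout described in the comments above, adds a hypothetical bees folder, and fills both with generated 32x32 images. It also sorts the os.listdir result so the index order is reproducible, which the class above does not do. Adding two Dataset objects with + returns a ConcatDataset that indexes through both in order.

```python
import os
import tempfile

from PIL import Image
from torch.utils.data import Dataset


class MyData(Dataset):
    """Folder-per-label dataset: the sub-folder name is the label."""

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # sorted() makes the index order deterministic (os.listdir order is arbitrary)
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


# Hypothetical tree: <tmp>/train/ants (3 images) and <tmp>/train/bees (2 images)
root_dir = os.path.join(tempfile.mkdtemp(), "train")
for label, count in [("ants", 3), ("bees", 2)]:
    os.makedirs(os.path.join(root_dir, label))
    for i in range(count):
        Image.new("RGB", (32, 32), color=(i * 40, 0, 0)).save(
            os.path.join(root_dir, label, f"{i}.png"))

ants_dataset = MyData(root_dir, "ants")
bees_dataset = MyData(root_dir, "bees")
# Dataset + Dataset gives a ConcatDataset covering both
train_dataset = ants_dataset + bees_dataset

img, label = train_dataset[3]  # index 3 is the first bees image
print(len(train_dataset), label, img.size)
```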

Four. Using common transforms

First import SummaryWriter from torch.utils.tensorboard. It writes event logs to a directory (here logs), and TensorBoard reads those logs and displays their contents, including images, in the browser.

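from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter('logs')
# img_path is a placeholder: set it to the path of the image you want to transform
# convert('RGB') guarantees 3 channels, which the 3-value Normalize below expects
img = Image.open(img_path).convert('RGB')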
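SummaryWriter.add_image expects a (C, H, W) tensor by default. To log an image that is still a NumPy array in (H, W, C) order, such as np.array(img) for a PIL image, the layout can be declared with the dataformats argument. A minimal sketch with a synthetic gradient image (the array is made up for illustration, and the tensorboard package must be installed for SummaryWriter to work):

```python
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

# Hypothetical 64x64 RGB image stored as (H, W, C) uint8, like np.array(pil_img)
h, w = 64, 64
img_array = np.zeros((h, w, 3), dtype=np.uint8)
img_array[..., 0] = np.linspace(0, 255, w, dtype=np.uint8)  # red grows left to right

# The default layout is CHW, so declare HWC explicitly for this array
writer.add_image("numpy_hwc", img_array, 0, dataformats="HWC")
writer.close()
```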

1. ToTensor method:

This class converts a PIL Image or a numpy.ndarray to a tensor.

# Using ToTensor
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)

Then type the following in a terminal:

tensorboard --logdir="logs" --port=6007

Open the link that TensorBoard prints in a browser. The output image is as follows


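In short, ToTensor converts a PIL Image or numpy.ndarray of shape (H, W, C) into a tensor of shape (C, H, W); for 8-bit images it also scales the values from [0, 255] to [0.0, 1.0].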

2. Normalize method

Normalize takes a per-channel mean and standard deviation and computes output = (input - mean) / std for each channel.

trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# Input the image that needs to be normalized
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image('Normalize', img_norm, 2)

The output is as follows


3. Resize method

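Resize changes the image size. Given a tuple (h, w), the output is exactly h x w pixels; given a single int, the shorter edge is scaled to that number and the aspect ratio is kept.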

trans_resize = transforms.Resize((3, 3))
img_resize = trans_resize(img_tensor)
writer.add_image('Resize', img_resize, 0)

The output is as follows


4. Compose method

Compose() takes a list of transforms and chains them into a single transform: the transforms run in list order, and each one's output becomes the next one's input. In the code below it combines the two Resize transforms.

trans_resize_2 = transforms.Resize(512)
# Chain the two Resize transforms: first to (3, 3), then the shorter edge to 512
trans_compose = transforms.Compose([trans_resize, trans_resize_2])
img_resize_2 = trans_compose(img_tensor)
writer.add_image('resize', img_resize_2, 1)

The output is as follows


Five. Combining dataset classes with transforms

First download a dataset. Since this is only practice, we use the relatively small CIFAR10 dataset (60,000 32x32 color images in 10 classes).

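root is the directory where the data is saved. train=True selects the training set and train=False selects the test set. transform is applied to every image as it is loaded (here it converts each image to a tensor), and download=True downloads the data if it is not already in root.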

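import torchvision
# Convert every image in the dataset to a tensor
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)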

Take the first ten images of the test set and write them to TensorBoard so they can be viewed in the browser

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')
for i in range(10):
    img, target = test_set[i]
    writer.add_image('test_set', img, i)
writer.close()

The output is as follows


2. Using DataLoader

Recall that DataLoader takes the samples provided by a Dataset and packs them into batches for the network.

Prepare the test dataset

import torchvision
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)

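batch_size=4 means that 4 samples are taken from the dataset and packed into one batch each time. shuffle=True reshuffles the sample order every epoch, num_workers=0 loads the data in the main process, and drop_last=False keeps the final batch even if it has fewer than 4 samples.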

from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
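To see what batch_size and drop_last do without the real data, here is a sketch that uses a hypothetical TensorDataset of 10 random 3x32x32 "images" in place of test_data. With batch_size=4, 10 samples give batches of 4, 4 and 2 when drop_last=False, and only the two full batches when drop_last=True; shuffling changes which samples land in each batch, not the batch sizes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for test_data: 10 random 3x32x32 images with integer labels
images = torch.rand(10, 3, 32, 32)
labels = torch.arange(10)
data = TensorDataset(images, labels)

for drop_last in (False, True):
    loader = DataLoader(dataset=data, batch_size=4, shuffle=True,
                        num_workers=0, drop_last=drop_last)
    # Each batch of images is stacked into shape (batch, 3, 32, 32)
    sizes = [imgs.shape[0] for imgs, targets in loader]
    print(drop_last, sizes)
# False [4, 4, 2]  (the last, smaller batch is kept)
# True [4, 4]      (the incomplete last batch is dropped)
```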

Write each batch of images to TensorBoard so it can be viewed in the browser

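from torch.utils.tensorboard import SummaryWriter

step = 0
writer = SummaryWriter('dataloader')
for data in test_loader:
    imgs, targets = data
    # imgs is a batch of shape (4, 3, 32, 32), so use add_images (plural)
    writer.add_images('test_data', imgs, step)
    step = step + 1
writer.close()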


The output is as follows

