This experiment practices loading data in PyTorch: using the Dataset class, the transforms module, and DataLoader.
I. Introduction to PyTorch
PyTorch is an open source Python machine learning library, based on Torch, for applications such as natural language processing.
In January 2017, Facebook AI Research (FAIR) released PyTorch, built on Torch. It is a Python-based scientific computing package that provides two high-level features: (1) tensor computation (like NumPy) with strong GPU acceleration, and (2) deep neural networks built on an automatic differentiation (autograd) system.
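Both features can be seen in a few lines. This is a minimal sketch, assuming only that torch is installed; the tensor is placed on the GPU if CUDA is available and on the CPU otherwise.

```python
import torch

# Feature 1: NumPy-like tensor computation, on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)
b = a @ a  # matrix product, computed on `device`
print(b)

# Feature 2: automatic differentiation (autograd)
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x  # y = x^2 + 2x
y.backward()        # computes dy/dx and stores it in x.grad
print(x.grad)       # dy/dx = 2x + 2 = 8 at x = 3
```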
II. PyTorch environment setup
Many tutorials online already cover setting up a PyTorch environment, so I won't go into detail here.
III. Basic use of the Dataset class
Dataset class: processes the data and provides a way to fetch each sample and its corresponding label.
DataLoader class: packs the samples fetched by the Dataset into batches and supplies them in the form the subsequent network needs.
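The division of labor between the two classes can be seen on a toy in-memory dataset. This is a sketch for illustration; `SquaresDataset` is a hypothetical name, not part of PyTorch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i], its label is i * i."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Dataset's job: return one (sample, label) pair for a given index
        return torch.tensor([float(idx)]), idx * idx

    def __len__(self):
        return self.n


dataset = SquaresDataset(6)
print(dataset[3])  # (tensor([3.]), 9)

# DataLoader's job: group samples into batches for the network
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batches = list(loader)
for samples, labels in batches:
    print(samples.shape, labels)
```

The Dataset only knows how to return one item; stacking items into batch tensors is done by the DataLoader.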
1. First import the Dataset class
```python
from torch.utils.data import Dataset
```
2. Create a class that inherits from Dataset
```python
import os

from PIL import Image
from torch.utils.data import Dataset


class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # os.path.join joins the two paths
        # If the root path is dataset\train and the label folder is ants, the result is dataset\train\ants
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir(path) returns a list of the names of all files and directories under path
        self.img_path = os.listdir(self.path)

    # Return one image and its label
    def __getitem__(self, idx):
        # idx is the index of the image; img_name is its file name
        img_name = self.img_path[idx]
        # Join the folder path and the file name to get the full image path
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # Open the image
        img = Image.open(img_item_path)
        # The folder name serves as the label
        label = self.label_dir
        # Return the image and its label
        return img, label

    def __len__(self):
        # Return the number of images
        return len(self.img_path)
```
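Since each MyData instance covers a single label folder, a common next step is to build one instance per class and join them. PyTorch's Dataset supports `+`, which returns a `ConcatDataset`. The sketch below uses a hypothetical in-memory stand-in, `FakeFolder`, instead of real image folders so it runs without any files. With the class above, the equivalent would be `MyData(root_dir, "ants") + MyData(root_dir, "bees")`, assuming an ants/bees folder layout.

```python
from torch.utils.data import ConcatDataset, Dataset


class FakeFolder(Dataset):
    """Hypothetical stand-in for MyData: returns (file name, label) instead of an image."""

    def __init__(self, label_dir, file_names):
        self.label_dir = label_dir
        self.img_path = file_names

    def __getitem__(self, idx):
        return self.img_path[idx], self.label_dir

    def __len__(self):
        return len(self.img_path)


ants = FakeFolder("ants", ["a0.jpg", "a1.jpg", "a2.jpg"])
bees = FakeFolder("bees", ["b0.jpg", "b1.jpg"])

# Dataset.__add__ returns ConcatDataset([ants, bees])
train_dataset = ants + bees
print(type(train_dataset).__name__, len(train_dataset))  # ConcatDataset 5
print(train_dataset[3])  # first item of bees: ('b0.jpg', 'bees')
```

Indices run through the first dataset and then continue into the second, so index 3 is the first bee image.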
IV. Common transforms
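First import SummaryWriter from torch.utils.tensorboard. SummaryWriter writes event files to a log directory (here logs), and TensorBoard reads them and displays the logged images in the browser.
```python
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
img = Image.open('images/220927.png').convert('RGB')  # convert('RGB') drops any PNG alpha channel
```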
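1. ToTensor method
ToTensor converts a PIL Image or numpy.ndarray to a tensor. A PIL Image, or an H x W x C uint8 array with values in [0, 255], becomes a torch.FloatTensor of shape C x H x W with values in [0.0, 1.0].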
```python
from torchvision import transforms

# Using ToTensor
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)
```
Then run the following in a terminal:
```shell
tensorboard --logdir="logs" --port=6007
```
Open the link that TensorBoard prints to view the logged image in the browser.
The function of this method is to convert the image to tensor type.
2. Normalize method
Normalize takes a per-channel mean and standard deviation and normalizes each channel of a tensor image as output = (input - mean) / std.
```python
tran_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# Pass in the tensor image to be normalized
img_norm = tran_norm(img_tensor)
# Print the top-left pixel of the first channel after normalization
print(img_norm[0][0][0])
writer.add_image('Normalize', img_norm, 2)
```
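3. Resize method
Resize changes the image size. With a tuple (h, w), as below, the output is exactly h x w; with a single int, the smaller edge is matched to that number and the aspect ratio is kept.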
```python
trans_resize = transforms.Resize((3, 3))
img_resize = trans_resize(img_tensor)
print(img_resize)
writer.add_image('Resize', img_resize, 0)
```
4. Compose method
Compose() takes a list of transforms and chains them: they are applied in list order, and the output of each becomes the input of the next. Here it chains the two Resize transforms, first to 3 x 3 and then so that the smaller edge is 512.
```python
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize, trans_resize_2])
img_resize_2 = trans_compose(img_tensor)
print(img_resize_2)
writer.add_image('resize', img_resize_2, 1)
```
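Conceptually, Compose calls each transform in order and feeds each result to the next. The sketch below re-implements that idea in plain Python with a hypothetical `MiniCompose` class and two toy functions. It illustrates the semantics and is not torchvision's actual source.

```python
class MiniCompose:
    """Hypothetical re-implementation of the idea behind transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        # Apply the transforms left to right: each output is the next input
        for t in self.transforms:
            x = t(x)
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


print(MiniCompose([add_one, double])(3))  # double(add_one(3)) = 8
print(MiniCompose([double, add_one])(3))  # add_one(double(3)) = 7
```

Order matters. In the Compose example above, the image is first shrunk to 3 x 3 and only then enlarged to 512 x 512, so the result is very blurry.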
V. Combining datasets with transforms
1. Downloading CIFAR10 with a transform
First download a dataset. Since this is only for practice, the small CIFAR10 dataset is used.
root is the directory to save to. train=True selects the training set and train=False the test set. download=True downloads the files if they are not already present, and transform converts each downloaded image to a tensor.
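```python
import torchvision

dataset_transform = torchvision.transforms.ToTensor()  # convert every image to a tensor
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
```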
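Fetch the first ten images of the test set and write them to TensorBoard:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')
for i in range(10):
    img, target = test_set[i]  # img is a 3 x 32 x 32 tensor, target is the class index
    writer.add_image('test_set', img, i)
writer.close()
```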
2. Using DataLoader
DataLoader wraps a Dataset, takes samples from it in batches (optionally shuffled), and supplies them in the form the subsequent network consumes.
Prepare the test dataset:
```python
import torchvision
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
```
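batch_size=4 means that 4 samples are taken from the dataset and packed together each time. shuffle=True reshuffles the order every epoch, num_workers=0 loads data in the main process, and drop_last=False keeps the last batch even if it has fewer than 4 samples.
```python
from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
```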
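The effect of batch_size and drop_last is easiest to see on a dataset whose size is not a multiple of the batch size. This is a sketch using `TensorDataset` with 10 made-up samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 hypothetical samples with 3 features each, labels 0..9
data = TensorDataset(torch.randn(10, 3), torch.arange(10))

keep = DataLoader(data, batch_size=4, shuffle=False, drop_last=False)
drop = DataLoader(data, batch_size=4, shuffle=False, drop_last=True)

keep_sizes = [x.shape[0] for x, y in keep]
drop_sizes = [x.shape[0] for x, y in drop]
print(keep_sizes)            # [4, 4, 2]: the last, incomplete batch is kept
print(drop_sizes)            # [4, 4]: the incomplete batch is dropped
print(len(keep), len(drop))  # 3 2
```

For the CIFAR10 test set (10,000 images) with batch_size=4, both settings give 2,500 batches, because 10,000 is divisible by 4.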
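Display the batched images in the browser (this log goes to the dataloader directory, so start TensorBoard with --logdir="dataloader"):
```python
from torch.utils.tensorboard import SummaryWriter

step = 0
writer = SummaryWriter('dataloader')
for data in test_loader:
    imgs, targets = data
    # add_images (plural) logs a whole batch of shape (N, C, H, W)
    writer.add_images('test_data', imgs, step)
    step = step + 1
writer.close()
```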