PyTorch framework learning (6): training a simple CNN of my own (3), details

The previous two articles followed a video tutorial, so many details were glossed over. This article goes back and records the details of each step.


Detail 1 : torchvision.transforms.Compose
When building the dataset, we used transforms.Compose to preprocess the data.
torchvision.transforms is PyTorch's image preprocessing package. Compose chains several transforms into a single callable that applies them in order, for example:

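from torchvision import transforms

data_transform = transforms.Compose([
    transforms.RandomRotation(45),  # random rotation, angle chosen between -45 and 45 degrees
    transforms.CenterCrop(224),  # crop a 224x224 patch from the center
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip with probability p
    transforms.RandomVerticalFlip(p=0.5),  # random vertical flip with probability p
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # brightness, contrast, saturation, hue
    transforms.RandomGrayscale(p=0.025),  # convert to grayscale with probability p; 3 channels are kept, with R=G=B
    transforms.ToTensor(),  # PIL image (HWC, 0-255) -> float tensor (CHW, 0.0-1.0)
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # per-channel mean, standard deviation
])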

Here, ToTensor converts a PIL image (or uint8 NumPy array) from HWC layout (height, width, channels) to a CHW tensor and rescales pixel values from 0 to 255 into 0.0 to 1.0.


Detail 2 : torchvision.datasets.ImageFolder
ImageFolder is a generic dataset class. It expects the training, validation, or test images to be organized with one subfolder per class, in this format:

# root is data/train or data/valid
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

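Parameters in detail:

import torchvision
from torchvision.datasets.folder import default_loader

dataset = torchvision.datasets.ImageFolder(
    root,                   # root directory of the images, i.e. the parent of the per-class folders
    transform=None,         # image preprocessing function (e.g. the Compose pipeline from Detail 1)
    target_transform=None,  # transform applied to the label (target); if None, the raw class indices 0, 1, 2, ... are returned
    loader=default_loader,  # how each image is loaded from its path; the default is usually fine
    is_valid_file=None)     # function that takes a file path and returns whether it is a valid file (e.g. to skip corrupted files)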

The resulting dataset object has three useful attributes:

self.classes: a list of the class names (the subfolder names, sorted)
self.class_to_idx: a dict mapping each class name to its index; these indices are the targets returned when no target_transform is given
self.imgs: a list of (image path, class index) tuples



Detail 3 : torch.utils.data.DataLoader
torch.utils.data.DataLoader wraps a dataset and serves it in (optionally shuffled) mini-batches.

import torch

dataloader = torch.utils.data.DataLoader(image_datasets,  # the data; must be a Dataset (e.g. the ImageFolder above)
                                         batch_size=8,    # batch size
                                         shuffle=True)    # whether to shuffle the data
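A minimal sketch of what iterating a DataLoader yields, using 20 random tensors wrapped in a TensorDataset instead of real images (an assumption so the example is self-contained). With batch_size=8 the loader produces two full batches and one smaller final batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 20 fake RGB "images" (3x32x32) with 0/1 labels; purely illustrative data
images = torch.randn(20, 3, 32, 32)
labels = torch.arange(20) % 2
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch_sizes = []
for x, y in loader:                 # one iteration = one mini-batch
    batch_sizes.append(x.shape[0])  # x: (batch, 3, 32, 32), y: (batch,)

print(len(loader))   # 3 batches per epoch
print(batch_sizes)   # [8, 8, 4]: the last batch is smaller (drop_last=True would discard it)
```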

For the other DataLoader parameters (such as num_workers and drop_last), see the official documentation or other blog posts.


Detail 4 : model.parameters() and model.state_dict()
model.parameters() and model.state_dict() are the two usual ways to access a network's parameters in PyTorch.
Generally, the former is used when initializing the optimizer, and the latter when saving the model.
For example:

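import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)  # the optimizer updates the learnable parameters
torch.save(model.state_dict(), "best_model.pth")  # the checkpoint stores parameters and buffers by name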

model.state_dict() returns an OrderedDict that maps each layer's parameter or buffer name (e.g. conv1.weight) to its tensor.
model.state_dict() contains all the learnable parameters (weights, biases) in the model, and also the non-learnable ones (the running mean and running var of BN layers, etc.). You can think of model.state_dict() as model.parameters() plus all the non-learnable parameters.
————————————————
Copyright statement: This is an original article by CSDN blogger "yaoyz105", licensed under the CC 4.0 BY-SA agreement; please attach the original source link and this statement when reprinting.
Original link: https://blog.csdn.net/qq_31347869/article/details/125065271
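A quick way to see the difference described above is to compare names, using model.named_parameters() (the named variant of model.parameters()) against the keys of model.state_dict(). The small conv + BatchNorm model is an arbitrary choice for illustration.

```python
import torch.nn as nn

# A small conv + BatchNorm model; BatchNorm also keeps non-learnable buffers
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))

param_names = [name for name, _ in model.named_parameters()]  # learnable only
state_keys = list(model.state_dict().keys())                  # parameters + buffers

print(param_names)
# ['0.weight', '0.bias', '1.weight', '1.bias']
print(sorted(set(state_keys) - set(param_names)))
# ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```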


Extra knowledge point 1 : the difference between deep copy, shallow copy and assignment

Python offers three common ways to "copy" an object: plain assignment =, shallow copy copy.copy(), and deep copy copy.deepcopy().

Assignment : assignment in Python passes a reference to the object (i.e. its memory address); nothing is copied, so both names refer to the same object.
Shallow copy : a shallow copy copies only the outer object; the objects nested inside it are shared with the original rather than copied.
Deep copy : a deep copy recursively copies the object and all the objects nested inside it.
For details, please refer to this blog
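The three behaviors can be seen with a nested list; the list contents are arbitrary.

```python
import copy

original = [1, 2, [3, 4]]

alias = original                # assignment: the very same object
shallow = copy.copy(original)   # new outer list, but the inner list is shared
deep = copy.deepcopy(original)  # new outer list and a new inner list

original[2].append(5)           # mutate the nested list

print(alias is original)    # True
print(shallow is original)  # False
print(shallow[2])           # [3, 4, 5]  (shares the nested list)
print(deep[2])              # [3, 4]     (fully independent)
```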

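  • model.state_dict() also behaves like a shallow copy: its tensors share memory with the model's parameters. If you set param=model.state_dict() and then modify a tensor in param in place, the model's parameters change accordingly. To keep an independent snapshot (e.g. of the best weights), use copy.deepcopy(model.state_dict()).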
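A minimal sketch of that sharing, assuming a throwaway nn.Linear(2, 1) model: an in-place edit through the state_dict reaches the model, while a deep copy is unaffected.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)

shared = model.state_dict()                   # tensors share memory with the model
snapshot = copy.deepcopy(model.state_dict())  # fully independent copy

shared["weight"].fill_(0.0)  # in-place edit through the state_dict ...

print(model.weight)          # ... so the model's weight is now all zeros too
print(snapshot["weight"])    # the deep copy still holds the original values
```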

Detail 5 : torch.load() and model.load_state_dict()
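torch.load(PATH) deserializes the object saved at PATH; when the file was written with torch.save(model.state_dict(), PATH), that object is a state_dict (an OrderedDict of tensors), not the model itself.
model.load_state_dict() copies those saved parameters back into a network model that has already been built with the same architecture.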

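import torch

# save: store only the parameters and buffers (the state_dict), not the whole model object
torch.save(model.state_dict(), PATH)

# load: rebuild the same architecture, then copy the saved tensors into it
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))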
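A complete round trip, as a sketch: `TinyNet` is a hypothetical stand-in for MyModel, and the checkpoint goes to a temporary directory. After loading, the restored network produces exactly the same outputs as the original.

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for MyModel
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

path = os.path.join(tempfile.mkdtemp(), "best_model.pth")

trained = TinyNet()
torch.save(trained.state_dict(), path)       # save only the state_dict

restored = TinyNet()                          # build the same architecture first
state = torch.load(path, map_location="cpu")  # an OrderedDict of tensors, not a model
restored.load_state_dict(state)               # copy the tensors into the network
restored.eval()                               # inference mode (matters for BN/Dropout)

x = torch.randn(3, 4)
print(torch.equal(trained(x), restored(x)))   # True
```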

Extra knowledge point 2 : Image.resize() and Image.thumbnail()

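Image.resize() returns a new image with exactly the requested size, so the aspect ratio can change. Image.thumbnail() modifies the current image in place (it returns None), shrinking it to fit within the given size while preserving the aspect ratio; it never enlarges the image.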
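The difference shows up clearly on a non-square image; the 400x200 blank image below is an arbitrary example.

```python
from PIL import Image

img = Image.new("RGB", (400, 200))    # 400 wide x 200 high

resized = img.resize((100, 100))      # new image, aspect ratio NOT preserved
print(resized.size, img.size)         # (100, 100) (400, 200)

thumb = img.copy()
result = thumb.thumbnail((100, 100))  # modifies `thumb` in place
print(thumb.size, result)             # (100, 50) None

small = Image.new("RGB", (50, 30))
small.thumbnail((100, 100))
print(small.size)                     # (50, 30): thumbnail never enlarges
```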

For details, please refer to this blog post~



Origin blog.csdn.net/vibration_xu/article/details/126176197