PyTorch advanced training skills and UNet demo

Datawhale202210 - "PyTorch in simple terms" (6)



Foreword

Training a model often requires flexible adjustments to fit the experiment at hand, and the utilities that PyTorch already provides should be used to the fullest in service of your own projects. That makes it worth strengthening these training skills further. It will be a long journey, and today is the first stop!


1. Custom loss function

Defining the loss as a function

The essence of a loss function is to "perform an operation on the inputs and produce an output", so it can be defined directly as an ordinary function, although this is not the common approach.

import torch

def my_loss(output, target):
    # mean squared error between prediction and target
    loss = torch.mean((output - target)**2)
    return loss
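
Because my_loss only uses differentiable tensor operations, the value it returns takes part in autograd just like a built-in criterion. A minimal usage sketch (the output and target tensors below are made up for illustration):

output = torch.randn(8, 10, requires_grad=True)   # stand-in for model predictions
target = torch.randn(8, 10)                       # stand-in for ground truth
loss = my_loss(output, target)
loss.backward()                                   # gradients flow back into output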

Defining the loss as a class

With this approach the loss function is treated as a layer of the neural network and, like a layer, inherits from the nn.Module class.

Dice Loss is a loss function commonly used in segmentation, defined as follows:

DSC(X, Y) = 2|X ∩ Y| / (|X| + |Y|),   Dice Loss = 1 - DSC

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)   # F.sigmoid is deprecated; use torch.sigmoid
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        return 1 - dice

# Usage
criterion = DiceLoss()
loss = criterion(inputs, targets)

Related background: nn.Module and nn.functional
nn.Module is a packaged class: a layer defined with it can maintain state and store parameter information, while nn.functional only provides the computation and keeps no state or stored parameters. Inside Module subclasses, the layer's computation is in fact implemented through nn.functional.
(There are of course many other common loss functions; a dedicated summary can be written later.)

2. Dynamically adjust the learning rate

If the learning rate is set too small, convergence slows down dramatically and training time grows; if it is set too large, the parameters may oscillate back and forth around the optimal solution.

Using the official schedulers

PyTorch has encapsulated the method of dynamically adjusting the learning rate in torch.optim.lr_scheduler. The specific scheduler is listed as follows:
lr_scheduler.LambdaLR
lr_scheduler.MultiplicativeLR
lr_scheduler.StepLR
lr_scheduler.MultiStepLR
lr_scheduler.ExponentialLR
lr_scheduler.CosineAnnealingLR
lr_scheduler.ReduceLROnPlateau
lr_scheduler.CyclicLR
lr_scheduler.OneCycleLR
lr_scheduler.CosineAnnealingWarmRestarts

The official usage pattern looks like this:

# choose an optimizer
optimizer = torch.optim.Adam(...)
# choose one or more of the dynamic learning-rate schedulers listed above
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# training
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # the learning rate should only be adjusted after the optimizer has updated the parameters
    scheduler1.step()
    ...
    schedulern.step()
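
As a concrete illustration, the sketch below (a minimal example with a single dummy parameter standing in for a real model) uses StepLR, which multiplies the learning rate by gamma every step_size epochs:

import torch
from torch.optim import SGD, lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))   # dummy parameter in place of model.parameters()
optimizer = SGD([param], lr=0.1, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    optimizer.step()        # train(...) / validate(...) would go here
    scheduler.step()        # lr: 0.1 -> 0.01 after epoch 29 -> 0.001 after epoch 59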

Custom scheduler

Although PyTorch officially provides a rich API, complicated experiments may still require a learning-rate adjustment strategy of your own. For example:

Suppose the learning rate needs to drop to one tenth of its value every 30 epochs and, hypothetically, the official API cannot meet this requirement; then the following definition is needed:

def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Calling adjust_learning_rate during training then changes the learning rate dynamically:

def adjust_learning_rate(optimizer,...):
    ...
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
    train(...)
    validate(...)
    adjust_learning_rate(optimizer,epoch)

3. Model fine-tuning - torchvision, timm

torchvision

Fine-tuning process

[Figure: typical fine-tuning workflow]

Using an existing model structure

Taking the common models in torchvision as an example, here is how to use the model structures and pretrained parameters that PyTorch provides for image classification tasks.

Instantiating a network

import torchvision.models as models
resnet18 = models.resnet18()
# resnet18 = models.resnet18(pretrained=False)  # equivalent to the line above
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()

Passing the pretrained parameter

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)

Precautions

1. PyTorch model weights usually use the extension .pt/.pth. When the program runs, it first checks whether the weights have already been downloaded to the default path; once downloaded, they do not need to be downloaded again.
2. To speed up model downloads, you can download the weights manually: look up the model's model_urls and fetch the file yourself. The default save paths are:
Linux, macOS: the .cache folder in the user's home directory
Windows: C:\Users\<username>\.cache\torch\hub\checkpoints
torch.utils.model_zoo.load_url() can also be used to specify the URL the weights are downloaded from.
3. You can also place downloaded weights in a folder of your own and then load the parameters into the network:

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

4. If you interrupt a download partway through, you must delete the partially downloaded weight file from the corresponding path, otherwise an error may be reported.

Training only specific layers

1. Set requires_grad = False to freeze some of the layers.

PyTorch's official tutorials provide the following routine:

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

2. Modify the fully connected layer at the output of the model.

import torchvision.models as models
# freeze the gradients of the pretrained parameters
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# modify the model: replace the final fully connected layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

3. During training the model still backpropagates gradients, but parameter updates only happen in the fc layer.
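
One common companion step (a minimal sketch, not from the tutorial) is to hand the optimizer only the parameters that still require gradients, which after freezing is just the new fc layer:

import torch.optim as optim

# collect only the trainable parameters
params_to_update = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params_to_update, lr=1e-3, momentum=0.9)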

timm (left for self-study)

Note: a "xxx is not defined" error often appears when running in a Jupyter notebook; it is usually because an earlier cell has not been run, and rerunning all cells may solve the problem.
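
Since timm is left for self-study, here is only a hedged pointer (assuming timm has been installed with pip install timm): timm.create_model loads a pretrained backbone and sets the number of output classes in one call.

import timm

model = timm.create_model('resnet18', pretrained=True, num_classes=4)
print(model.default_cfg)   # input size, normalization mean/std, classifier layer name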

4. Half-precision training

Half-precision training improves GPU efficiency; more precisely, it reduces memory usage so that more data can be loaded at once.

Setting up half precision training in PyTorch

import autocast

from torch.cuda.amp import autocast

Model setup: apply autocast as a Python decorator on the forward method.

@autocast()   
def forward(self, x):
    ...
    return x

Training process: wrap the forward pass, i.e. feeding the data into the model and everything after it, in "with autocast():".

for x in train_loader:
    x = x.cuda()
    with autocast():
        output = model(x)
        ...

Extensions

1. Another component of the half-precision setup: GradScaler

GradScaler is the gradient scaler module; a GradScaler object needs to be instantiated before training begins. The classic AMP recipe in PyTorch therefore looks like this:

import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# create the model; parameters default to torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# instantiate a GradScaler object before training starts
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # run the forward pass (model + loss) under autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # scale the loss, backpropagate, then step the optimizer through the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

2. A pitfall: where nan comes from

When computing the loss, a division by zero, a loss large enough to be judged as inf under half precision, or a nan already present in the network parameters will all make the result of the computation come out as nan.
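
A defensive check can help locate the problem. The helper below is a hedged sketch of my own, not part of the tutorial: it skips the update when the loss is already inf or nan (note that GradScaler itself already skips optimizer steps whose gradients contain inf/nan).

import torch

def safe_backward(loss, scaler, optimizer):
    # skip this batch entirely if the loss has overflowed to inf or become nan
    if not torch.isfinite(loss):
        optimizer.zero_grad()
        return False
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return True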

5. Data augmentation - imgaug

Installing imgaug

conda

conda config --add channels conda-forge
conda install imgaug

pip

#  install imgaug either via pypi

pip install imgaug

#  install the latest version directly from github

pip install git+https://github.com/aleju/imgaug.git

Using imgaug with PyTorch

There is no fixed template for this; it is recommended to look for answers as questions come up during use. One common pattern is sketched below.
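
The sketch below is a minimal example under my own assumptions, not an official template: run the imgaug pipeline inside a Dataset's __getitem__ on numpy uint8 HWC images, and only then convert to a tensor.

import torch
import imgaug.augmenters as iaa
from torch.utils.data import Dataset

class AugmentedDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images    # list of HxWx3 uint8 numpy arrays (assumed already loaded)
        self.labels = labels
        self.aug = iaa.Sequential([
            iaa.Fliplr(0.5),
            iaa.Affine(rotate=(-15, 15)),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.aug(image=self.images[idx])
        # .copy() avoids the negative-stride arrays that torch.from_numpy rejects
        img = torch.from_numpy(img.copy()).permute(2, 0, 1).float() / 255.0
        return img, self.labels[idx]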

6. Use argparse to adjust parameters

Introduction to argparse

argparse is Python's standard module for command-line parsing. It is built into Python and needs no installation; it can parse, store, and use parameters passed in from the command line.

Use of argparse

  • Create an ArgumentParser() object
  • Call the add_argument() method to add parameters
  • Use parse_args() to parse arguments
# demo.py
import argparse

# create an ArgumentParser() object
parser = argparse.ArgumentParser()

# add arguments
parser.add_argument('-o', '--output', action='store_true',
    help="shows output")
# action='store_true' records the output argument as True when the flag is given
# type specifies the argument's type
# default specifies the default value
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')

parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
# parse the arguments with parse_args()
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate: {args.lr}")

Enter at the command line

python demo.py -o --lr 3e-4 --batch_size 32

output

This is some output
learning rate: 0.0003

Advanced argparse

1. To make the code cleaner and more modular, the operations related to hyperparameters can be written in config.py and then imported in train.py or other files (from ZhikangNiu); see the sketch after this list.
2. (to be continued)
...
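
As a hedged sketch of that idea (config.py and get_options are names made up here, not from the tutorial):

# config.py
import argparse

def get_options():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=3e-5)
    parser.add_argument('--batch_size', type=int, default=32)
    return parser.parse_args()

# train.py would then simply do:
# from config import get_options
# args = get_options()
# ... and use args.lr, args.batch_size in the training loop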

7. Reference documents

From Datawhale

Online tutorial: https://datawhalechina.github.io/thorough-pytorch/
GitHub repository: https://github.com/datawhalechina/thorough-pytorch
Gitee repository: https://gitee.com/datawhalechina/thorough-pytorch
Bilibili video: https://www.bilibili.com/video/BV1L44y1472Z
(Everyone is welcome to star and follow!)

From the official timm resources

GitHub: https://github.com/rwightman/pytorch-image-models
Documentation: https://fastai.github.io/timmdocs/ and https://rwightman.github.io/pytorch-image-models/

From the community

Detailed explanation of PyTorch loss functions: https://blog.csdn.net/qq_27825451/article/details/95165265
Solving nan problems in half-precision training: https://github.com/huggingface/transformers/issues/4287
Using imgaug with PyTorch: https://github.com/aleju/imgaug/issues/406
Detailed usage examples of the argparse module: https://zhuanlan.zhihu.com/p/56922793

Summary

1. In this chapter I appreciated even more how thoughtfully our tutorial is arranged: it often lets readers understand the background that led a method to appear or be upgraded, which keeps everything clear. Thanks to our project contributors.
2. Due to time constraints, much of the content in the tutorial can only show the core part; in practice all kinds of problems will still come up, and it helps to get used to collecting and organizing reference material.
3. Studying this chapter also gave me a small insight into why computer-science fundamentals matter: a deeper understanding of the hardware can sometimes help us learn the algorithms better.
4. Thanks again to the friends at Datawhale, especially Dr. Dennis and Qiqi on the team.

Original post: https://blog.csdn.net/weixin_50967907/article/details/127410671