Datawhale Zero-Foundation Entry CV Competition - Task04 Model Training and Validation

In the previous chapter we built a simple CNN, trained it, and visualized the training loss and the prediction accuracy of the first character. That is far from enough, however. A mature deep learning training pipeline should at least be able to:

  • Train on the training set and validate on the validation set;
  • Save the best model weights and load them back when needed;
  • Record the accuracy on the training set and the validation set to guide hyperparameter tuning.

4 Model training and validation

To this end, this chapter covers how to construct a validation set, how to train and validate a model, how to save and load model weights, and how to tune the model. Where appropriate, the explanations are accompanied by PyTorch code.

4.1 Methods to prevent overfitting

Machine learning models, and deep learning models in particular, overfit very easily. As training proceeds, the training error of a deep learning model keeps decreasing, but the test error does not necessarily follow the same trend.
During training, the model can only use the training data and never touches the samples in the test set. If the model fits the training set too well, it memorizes the details of the training samples and generalizes poorly on the test set; this phenomenon is called overfitting. The opposite of overfitting is underfitting, where the model fits even the training set poorly.
[Figure: training error and test error as model complexity increases]
As shown in the figure: as model complexity and the number of training epochs increase, the error of the CNN model on the training set keeps decreasing, while the error on the test set first decreases and then gradually rises. What we are pursuing is the highest possible accuracy on the test set.
There are many causes of model overfitting. The most common one is that the model complexity is too high, so the model learns every aspect of the training data, including incidental fine-grained patterns. The main remedies for overfitting are data augmentation, weight decay, early stopping, and dropout.

  • Data Augmentation

Generally, obtaining a better model requires a large number of trainable parameters, which is one reason CNNs keep getting deeper. However, if the training samples lack diversity, more parameters are of little use: the model overfits and its generalization ability suffers accordingly. The feature diversity brought by a large amount of data helps make full use of all the trainable parameters. Data augmentation methods generally include: 1) collecting more data; 2) randomly rotating, flipping, or cropping existing images, adjusting their brightness and contrast, and normalizing the data; 3) using generative models (such as GANs) to synthesize new data.
Common augmentation methods transform the image's color, size, shape, spatial layout, and pixels. Different augmentation methods can also be combined freely to obtain richer augmentation schemes.
Take torchvision.transforms as an example. An overview of the available augmentation methods includes the following (a short composition sketch follows the list):

Center crop: transforms.CenterCrop
Random crop: transforms.RandomCrop
Random resized crop (random size and aspect ratio): transforms.RandomResizedCrop
Crop the four corners and the center: transforms.FiveCrop
Crop the four corners and the center, plus their flipped versions: transforms.TenCrop
Horizontal flip with probability p: transforms.RandomHorizontalFlip(p=0.5)
Vertical flip with probability p: transforms.RandomVerticalFlip(p=0.5)
Random rotation: transforms.RandomRotation
Random erasing (occlusion) of image regions: transforms.RandomErasing
Resize: transforms.Resize
Normalization: transforms.Normalize
Padding: transforms.Pad
Adjust brightness, contrast and saturation: transforms.ColorJitter
Convert to grayscale: transforms.Grayscale
Convert to grayscale with probability p: transforms.RandomGrayscale
Linear transformation: transforms.LinearTransformation()
Random affine transformation: transforms.RandomAffine
Convert data to a PIL Image: transforms.ToPILImage
Convert to a tensor and scale to [0, 1]: transforms.ToTensor
User-defined transform: transforms.Lambda
transforms.RandomChoice(transforms): pick one transform from the given list and apply it
transforms.RandomApply(transforms, p=0.5): apply the given transforms with probability p
transforms.RandomOrder: apply the given transforms in a random order
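
As a minimal sketch of combining several of the transforms above with transforms.Compose (the target size 64×128 and the normalization statistics are illustrative assumptions, not values from the competition baseline):

import numpy as np
from PIL import Image
from torchvision import transforms

# A possible augmentation pipeline for training images (all values are illustrative)
train_transform = transforms.Compose([
    transforms.Resize((64, 128)),                          # resize to a fixed input size
    transforms.ColorJitter(brightness=0.3, contrast=0.3),  # random brightness/contrast
    transforms.RandomRotation(5),                          # small random rotation
    transforms.ToTensor(),                                 # PIL image -> tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406],            # per-channel mean
                         [0.229, 0.224, 0.225]),           # per-channel std
])

# Random image standing in for a real training sample
img = Image.fromarray(np.random.randint(0, 255, (64, 128, 3), dtype=np.uint8))
x = train_transform(img)
print(x.shape)   # torch.Size([3, 64, 128])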
  • Weight Decay

Commonly used weight decay comes in the form of L1 and L2 regularization. L1 yields sparser parameters than L2, but the L1 penalty is not differentiable at zero. In the loss function, the weight decay coefficient is placed in front of the regularization term, which generally measures the model's complexity; the role of weight decay is therefore to control how much model complexity contributes to the loss. With a large weight decay coefficient, a complex model incurs a correspondingly large loss.
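
In PyTorch, L2 weight decay is usually applied through the optimizer's weight_decay argument. A minimal sketch (the stand-in model and the coefficient 1e-4 are illustrative, not values from this tutorial):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)   # stand-in model for illustration
# weight_decay adds an L2 penalty on the parameters to the objective
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)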

  • Early Stopping

Early stopping is in fact another regularization method: after each iteration (or epoch), compute the error on the training set and on the validation set, and stop training when the validation error reaches its minimum, before it starts to rise. If training continues, the error on the training set generally keeps decreasing while the error on the validation set increases, which means the model's generalization ability has started to deteriorate and overfitting has set in; stopping in time therefore yields a model that generalizes better. This is shown in the figure below (left: training-set error; right: validation-set error; training is stopped early at the dashed line):
[Figure: training-set error (left) and validation-set error (right); training is stopped at the dashed line]
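
A minimal sketch of early stopping, reusing the model, data loaders, and optimizer built in Section 4.3 below; train_one_epoch and evaluate are hypothetical helpers that return the average training and validation loss for one epoch:

import torch

best_val_loss = float('inf')
patience, bad_epochs = 5, 0   # stop after 5 epochs without improvement (illustrative)

for epoch in range(100):
    train_loss = train_one_epoch(model, train_loader, optimizer)   # hypothetical helper
    val_loss = evaluate(model, val_loader)                         # hypothetical helper

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        bad_epochs = 0
        torch.save(model.state_dict(), './best_model.pt')   # keep the best weights
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print('Early stopping at epoch', epoch)
            break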

  • Dropout

Using dropout during CNN training means randomly setting the outputs of some neurons to zero in each training pass, i.e., temporarily disabling some neurons; this reduces the number of parameters that co-adapt and helps avoid overfitting. There are two common views on why dropout works: 1) randomly disabling different neurons in each iteration increases model diversity and achieves an effect similar to ensembling multiple models, which prevents overfitting; 2) dropout can also be seen as a form of data augmentation that induces sparsity and makes the differences between local data clusters more pronounced, which is another reason it helps prevent overfitting.
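
A minimal sketch of adding dropout between fully connected layers (the layer sizes, dropout probability, and number of classes are illustrative, not taken from the competition model):

import torch
import torch.nn as nn

classifier = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),    # randomly zero 50% of the activations during training
    nn.Linear(256, 11),   # e.g. 10 digits plus one "blank" class (illustrative)
)

classifier.train()                 # dropout is active in training mode
x = torch.randn(4, 512)
print(classifier(x).shape)         # torch.Size([4, 11])

classifier.eval()                  # dropout is disabled at evaluation time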

  • Construct a validation set with the same distribution as the test set

The most effective answer to the overfitting problem is to construct a sample set whose distribution is as consistent as possible with the test set (the validation set), continuously check the model's accuracy on this validation set during training, and use it to steer the training of the model.

4.2 Constructing a validation set

In general, contestants can split off a validation set locally for local evaluation. The training set, validation set, and test set serve different purposes:

  • Train Set: used to train the model and fit the model parameters;
  • Validation Set: used to evaluate model accuracy and tune hyperparameters;
  • Test Set: used to evaluate the generalization ability of the model.

Because the training set and the validation set are disjoint, the model's accuracy on the validation set reflects its generalization ability to some extent. When splitting off a validation set, make sure its distribution is as consistent as possible with that of the test set; otherwise the model's accuracy on the validation set loses its guiding value.
Since the validation set is so important, how do we split one off locally? In some competitions the organizers provide a validation set; if they do not, contestants need to carve one out of the training set. There are several ways to do this:
[Figure: common ways of splitting off a validation set]

  • Hold-Out

The training set is directly split into two parts: a new training set and a validation set. The advantage of this method is that it is the most direct and simple; the disadvantage is that only a single validation set is obtained, so the model may overfit to it. The hold-out method is suited to situations where the amount of data is relatively large.
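
A minimal sketch of a hold-out split with torch.utils.data.random_split (the stand-in dataset and the 90/10 ratio are illustrative assumptions):

import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset; in the competition this would be the street-view character dataset
full_dataset = TensorDataset(torch.randn(100, 3, 64, 128),
                             torch.randint(0, 11, (100, 6)))

n_val = int(0.1 * len(full_dataset))          # hold out 10% for validation
n_train = len(full_dataset) - n_val
train_dataset, val_dataset = random_split(full_dataset, [n_train, n_val])
print(len(train_dataset), len(val_dataset))   # 90 10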

  • Cross Validation (CV)

Split the training set into K folds, use K-1 folds for training and the remaining fold for validation, and repeat the training K times so that every fold serves as the validation set once. The final validation accuracy is the average over the K folds. The advantage of this method is that the validation accuracy is more reliable, and the K runs yield K models with useful diversity; the disadvantage is that it requires training K times, so it is not suitable when the amount of data is very large. The figure below is a schematic of 10-fold cross-validation:

[Figure: schematic of 10-fold cross-validation]
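
A minimal sketch of K-fold cross-validation, using sklearn.model_selection.KFold to generate index splits and torch.utils.data.Subset to build the per-fold datasets (scikit-learn and the full_dataset object from the hold-out sketch above are assumptions):

import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import Subset

kf = KFold(n_splits=10, shuffle=True, random_state=0)
indices = np.arange(len(full_dataset))

for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
    train_subset = Subset(full_dataset, train_idx.tolist())
    val_subset = Subset(full_dataset, val_idx.tolist())
    print(f'fold {fold}: {len(train_subset)} train / {len(val_subset)} val')
    # train a fresh model on train_subset and evaluate it on val_subset here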

  • Bootstrap Sampling (Bootstrap)

A new training set and validation set are obtained by sampling with replacement, so each resampled training set and validation set is different. This method is generally suitable when the amount of data is small. In machine learning, bootstrapping helps a model or algorithm expose the bias, variance, and feature behavior present in the data: resampling with replacement produces sample groups with different compositions, which changes the overall mean, standard deviation, and other descriptive statistics of the dataset, and in turn helps develop more robust models. This is illustrated in the figure below.
[Figure: bootstrap resampling produces sample groups with different compositions]
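
A minimal sketch of a bootstrap split: sample training indices with replacement and use the never-drawn (out-of-bag) samples for validation (the helper name and the full_dataset object from the hold-out sketch are assumptions):

import numpy as np
from torch.utils.data import Subset

def bootstrap_split(dataset, seed=0):
    # Draw len(dataset) indices with replacement; out-of-bag samples form the validation set
    rng = np.random.default_rng(seed)
    n = len(dataset)
    train_idx = rng.integers(0, n, size=n)
    oob_idx = np.setdiff1d(np.arange(n), train_idx)
    return Subset(dataset, train_idx.tolist()), Subset(dataset, oob_idx.tolist())

train_subset, val_subset = bootstrap_split(full_dataset)
print(len(train_subset), len(val_subset))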
In this competition, a validation set has already been provided, so contestants can train directly on the training set and check accuracy on the given validation set (of course, you can also merge the training set and validation set and split your own validation set).
These methods differ only in how the data is split. The splits most commonly used in data competitions are the hold-out method and cross-validation; if the amount of data is relatively large, the hold-out method is usually the more appropriate choice. Whichever method is used, the resulting validation set must keep the training-validation-test distribution consistent, so this point requires attention no matter how the split is performed.
The distribution here generally refers to statistics related to the labels. In a classification task, it means the class distribution of the labels: the class distributions of the training, validation, and test sets should be roughly the same. If the labels carry temporal information, the time interval between the validation set and the test set should also be kept consistent.
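
When the class distribution matters, a stratified split keeps the label proportions consistent between the training and validation sets. A minimal sketch with sklearn.model_selection.train_test_split (the stand-in labels and the 10% split are illustrative assumptions):

import numpy as np
from sklearn.model_selection import train_test_split

labels = np.random.randint(0, 10, size=1000)   # stand-in labels for illustration
indices = np.arange(len(labels))

train_idx, val_idx = train_test_split(
    indices, test_size=0.1, stratify=labels, random_state=0)   # preserve class proportions
print(np.bincount(labels[train_idx]) / len(train_idx))
print(np.bincount(labels[val_idx]) / len(val_idx))   # roughly the same proportions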

4.3 Model training and validation

Here we use PyTorch to complete the CNN training and validation process; the CNN architecture is the same as in the previous chapter. The logic we need to implement is as follows:

  • Construct the training set and validation set;
  • Train and validate in each epoch, and save the model checkpoint that achieves the best validation performance.
import numpy as np
import torch
import torch.nn as nn

# train_dataset, val_dataset and SVHN_Model1 are defined in the previous chapters
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=10, 
    shuffle=True, 
    num_workers=10, 
)   

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=10, 
    shuffle=False, 
    num_workers=10, 
)

model = SVHN_Model1()
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0

for epoch in range(20):
    print('Epoch: ', epoch)
    train(train_loader, model, criterion, optimizer, epoch)
    val_loss = validate(val_loader, model, criterion)   
    # Save the checkpoint with the lowest validation loss
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), './model.pt')

The training code for each epoch is as follows:

def train(train_loader, model, criterion, optimizer, epoch):
    # Switch the model to training mode
    model.train()
    
    for i, (input, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4, c5 = model(input)
        loss = criterion(c0, target[:, 0]) + \
                criterion(c1, target[:, 1]) + \
                criterion(c2, target[:, 2]) + \
                criterion(c3, target[:, 3]) + \
                criterion(c4, target[:, 4]) + \
                criterion(c5, target[:, 5])
        # Average the six character losses
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The validation code for each epoch is as follows:

def validate(val_loader, model, criterion):
    # Switch the model to evaluation mode
    model.eval()
    val_loss = []
    
    # Do not track gradients during validation
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            c0, c1, c2, c3, c4, c5 = model(input)
            loss = criterion(c0, target[:, 0]) + \
                    criterion(c1, target[:, 1]) + \
                    criterion(c2, target[:, 2]) + \
                    criterion(c3, target[:, 3]) + \
                    criterion(c4, target[:, 4]) + \
                    criterion(c5, target[:, 5])
            loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)

Model saving and loading
Saving and loading models in PyTorch is very simple. The most common approach is to save and load only the model parameters:
torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))
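
If training may need to be resumed later, the optimizer state can be saved together with the model parameters. A minimal sketch reusing the model, optimizer, epoch, and best_loss from the training loop above (the file name checkpoint.pt is an illustrative choice):

# Save model and optimizer state together so training can be resumed
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_loss': best_loss,
}
torch.save(checkpoint, 'checkpoint.pt')

# Restore later
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])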

4.4 Model tuning process

Deep learning has few hard-and-fast principles and is highly empirical: many design choices can only be verified by actually training the model. At the same time, deep learning involves many network architectures and hyperparameters, so repeated experimentation is necessary. Training deep learning models also requires GPU hardware and substantial training time, and how to train them effectively has gradually become a topic in its own right.
There are many training tricks for deep learning. A recommended read is:

  • http://karpathy.github.io/2019/04/25/recipe/


  • PyTorch optimizer tuning

torch.optim is a package that implements various optimization algorithms. The most commonly used methods are already supported, and the interface is general enough that more complex methods can easily be integrated in the future. The available optimizers include the following (a short construction sketch follows the list):

optim.SGD: stochastic gradient descent
optim.Adagrad: gradient descent with adaptive learning rates
optim.RMSprop: an improvement on Adagrad
optim.Adadelta: an improvement on Adagrad
optim.Adam: RMSprop combined with momentum
optim.Adamax: Adam with an upper bound on the learning rate
optim.SparseAdam: a sparse version of Adam
optim.ASGD: averaged stochastic gradient descent
optim.Rprop: resilient backpropagation
optim.LBFGS: an improvement on BFGS
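
As a minimal construction sketch for two of these optimizers (the stand-in model and the hyperparameter values are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)   # stand-in model for illustration

# SGD with momentum and L2 weight decay
optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01,
                                momentum=0.9, weight_decay=1e-4)

# Adam, as used in the training code above
optimizer_adam = torch.optim.Adam(model.parameters(), lr=0.001)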
  • Six learning rate adjustment strategies in PyTorch

Ordered adjustment: StepLR, MultiStepLR, ExponentialLR and CosineAnnealingLR
Adaptive adjustment: ReduceLROnPlateau
Custom adjustment: LambdaLR
Learning rate initialization: set a relatively small value, such as 0.01, 0.001 or 0.0001

# Decay the learning rate at equal intervals
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1)
# Decay the learning rate at the given milestones
optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)
# Decay the learning rate exponentially
optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
# Adjust the learning rate along a cosine annealing schedule
optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)
# Monitor a metric and reduce the learning rate when it stops improving
optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=verbose)
# Custom adjustment strategy
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
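
A minimal sketch of wiring a scheduler into the epoch loop from Section 4.3 (the StepLR settings are illustrative; for ReduceLROnPlateau the validation loss is passed to step instead):

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(20):
    train(train_loader, model, criterion, optimizer, epoch)
    val_loss = validate(val_loader, model, criterion)
    scheduler.step()   # StepLR: step once per epoch, after the optimizer updates
    # For ReduceLROnPlateau, call scheduler.step(val_loss) instead
    print('Epoch', epoch, 'lr =', optimizer.param_groups[0]['lr'])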

References

Computer Vision Practice (Street View Character Recognition)
datawhalechina
Dive into Deep Learning

Datawhale is an open source organization focusing on data science and AI. It brings together excellent learners from many universities and well-known companies across many fields, gathering a group of team members with an open source and exploratory spirit. With the vision of "for the learner, grow with learners", Datawhale encourages genuine self-expression, openness and tolerance, mutual trust and mutual assistance, the courage to try and make mistakes, and the courage to take responsibility. At the same time, Datawhale uses the concept of open source to explore open source content, open source learning and open source solutions, empowering talent training, helping talent growth, and building connections between people and people, people and knowledge, people and enterprises, and people and the future.
