Machine Learning Fundamentals: Six Ways to Solve the Problem of Model Overfitting

In machine learning, overfitting degrades a model's predictive performance. It usually occurs when the model is too complex, for example when it has too many parameters. This article summarizes what overfitting is and how to address it.


As Buffett said, "It's better to be approximately right than exactly wrong."


In machine learning, a model is said to be overfitting when it fits the training data too closely and misses the underlying pattern. Its predictions on new data are then far from the correct answers, i.e. its accuracy drops. Such a model treats noise in irrelevant data as signal, which hurts accuracy: even if the model is trained to a very small loss, it still performs poorly on new data. Underfitting is the opposite case, where the model fails to capture the logic of the data, so an underfitted model has lower accuracy and higher loss.


How to determine if a model is overfitting?

When building a model, the data is divided into three parts: a training set, a validation set, and a test set. The training set is used to train the model; the validation set is used to check the model after each training step; the test set is used to evaluate the model at the very end. The data is usually split in a ratio such as 80:10:10 or 70:20:10.
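
For example, here is a minimal sketch of such a split using scikit-learn's `train_test_split` (the synthetic dataset and the 80:10:10 ratio are illustrative choices, not something from the original article):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Dummy dataset standing in for the real data
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

# First hold back 20% of the data for validation + test
X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, test_size=0.2, random_state=42)

# Split that 20% in half: 10% validation, 10% test overall
X_val, X_test, y_val, y_test = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=42)

print(len(X_train), len(X_val), len(X_test))  # 800 100 100
```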

While the model is being built, evaluate it on the validation data at every epoch and record the training loss and accuracy as well as the validation loss and accuracy for that epoch. Once training is finished, evaluate the model on the test data to get its final accuracy. If there is a large gap between the training accuracy and the validation accuracy, the model is overfitted.

If the loss is high on the training set as well as on the validation and test sets, then the model is underfitted.
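
As a rough illustration, this comparison can be done programmatically. The helper function and thresholds below are purely illustrative assumptions, not something prescribed by the article:

```python
def diagnose(train_acc, val_acc, train_loss, val_loss, gap=0.10, high_loss=1.0):
    """Crude heuristic: a large train/validation accuracy gap suggests overfitting,
    a high loss everywhere suggests underfitting. Thresholds are illustrative."""
    if train_acc - val_acc > gap:
        return "likely overfitting"
    if train_loss > high_loss and val_loss > high_loss:
        return "likely underfitting"
    return "no obvious problem"

# Example numbers, e.g. read off the final epoch of a training history
print(diagnose(train_acc=0.98, val_acc=0.75, train_loss=0.05, val_loss=0.90))  # likely overfitting
```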

How to prevent overfitting

Cross-validation

Cross-validation is a powerful way to prevent overfitting. In cross-validation, we generate multiple train-test splits of the data and use them to tune the model. K-fold cross-validation is the standard method: the data is divided into k subsets (folds), and each fold in turn is held out for validation while the remaining folds are used to train the algorithm.

Cross-validation allows hyperparameters to be tuned, with performance reported as the average over all folds. The method is computationally expensive, but it does not waste much data. The cross-validation process is shown in the following figure:

[Figure: k-fold cross-validation, with each fold used once for validation]
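
A minimal sketch of 5-fold cross-validation with scikit-learn (the synthetic data, the logistic regression model, and k = 5 are illustrative choices):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=500, n_features=20, random_state=0)
model = LogisticRegression(max_iter=1000)

# Each fold is used once for validation while the remaining folds train the model
scores = cross_val_score(model, X, y, cv=5)
print(scores)         # accuracy on each of the 5 folds
print(scores.mean())  # performance reported as the average over all folds
```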

Train with more data

Training the model with more relevant data helps it identify the real signal and avoid mistaking noise for signal. Data augmentation is one way to increase the amount of training data; for images it can be achieved by flipping, translation, rotation, scaling, changing brightness, and so on.
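
As one possible illustration, assuming torchvision's transform API (an assumption, not something specified in the article), those augmentations can be written as:

```python
from torchvision import transforms

# Each transform corresponds to one of the augmentations listed above;
# the parameter values are only illustrative.
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),       # flipping
    transforms.RandomAffine(degrees=15,           # rotation
                            translate=(0.1, 0.1), # translation
                            scale=(0.9, 1.1)),    # scaling
    transforms.ColorJitter(brightness=0.2),       # brightness change
    transforms.ToTensor(),
])
# `augment` can then be passed as the `transform` argument of an image dataset.
```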

Remove features

Removing features reduces the complexity of the model and, to a certain extent, filters out noise, making the model more efficient. To reduce complexity, we can also make the network smaller by removing layers or reducing the number of neurons.
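
For example, here is a sketch of dropping less informative input features with scikit-learn (SelectKBest and k = 10 are illustrative assumptions, not something the article prescribes):

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif

# 20 features, only 5 of which actually carry signal
X, y = make_classification(n_samples=500, n_features=20, n_informative=5, random_state=0)

# Keep the 10 features most associated with the target and drop the rest
selector = SelectKBest(score_func=f_classif, k=10)
X_reduced = selector.fit_transform(X, y)
print(X.shape, "->", X_reduced.shape)  # (500, 20) -> (500, 10)
```

For a neural network, the analogous step is simply to define fewer layers or fewer neurons per layer.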

Early stopping

When training a model iteratively, we can measure its performance after each iteration. Once the validation loss starts to increase, we should stop training the model; this prevents overfitting.

The diagram below shows when to stop training the model:

[Figure: training and validation loss curves; training stops where the validation loss begins to rise]
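
In Keras, for example, this behaviour is available through the `EarlyStopping` callback; the patience value below is just an illustrative choice:

```python
import tensorflow as tf

# Stop when the validation loss has not improved for 3 consecutive epochs,
# and roll back to the weights from the best epoch seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# model.fit(X_train, y_train,
#           validation_data=(X_val, y_val),
#           epochs=100,
#           callbacks=[early_stop])
```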

Regularization

Regularization can be used to reduce model complexity. It works by adding a penalty term to the loss function, and the penalty can take one of two forms, L1 or L2, shown below:

$$\text{Loss}_{L1} = \text{Loss} + \lambda \sum_{i} |w_i|$$

The L1 penalty minimizes the sum of the absolute values of the weights. It produces a simple, interpretable model and is robust to outliers.

$$\text{Loss}_{L2} = \text{Loss} + \lambda \sum_{i} w_i^2$$

The L2 penalty minimizes the sum of the squares of the weight values. The resulting model can learn complex data patterns but is not robust to outliers.

Both of these regularization methods help to solve the problem of overfitting, and readers can choose to use them according to their needs.
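
As a concrete illustration, scikit-learn exposes these two penalties as Lasso (L1) and Ridge (L2) regression; the synthetic data and alpha values below are arbitrary:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, Ridge

X, y = make_regression(n_samples=200, n_features=10, noise=5.0, random_state=0)

# L1 penalty (lambda * sum|w|): tends to drive some weights exactly to zero
lasso = Lasso(alpha=0.1).fit(X, y)

# L2 penalty (lambda * sum w^2): shrinks all weights smoothly towards zero
ridge = Ridge(alpha=0.1).fit(X, y)

print("L1 weights:", lasso.coef_.round(2))
print("L2 weights:", ridge.coef_.round(2))
```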

Dropout

Dropout is a regularization method that randomly disables a fraction of the neural network's units during training. It can be applied to any hidden layer or to the input layer, but not to the output layer. Dropping units breaks up dependencies between neurons, forcing the network to learn features that are useful on their own. The method effectively thins out the network, as shown in the following figure:

[Figure: a fully connected network before and after dropout is applied]
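
A minimal Keras sketch, assuming a small fully connected classifier (the architecture and the 0.5 dropout rate are illustrative):

```python
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    layers.Dense(64, activation="relu"),
    layers.Dropout(0.5),   # randomly disables 50% of this layer's units at each training step
    layers.Dense(32, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(1, activation="sigmoid"),  # no dropout on the output layer
])
model.summary()
```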

Summary

Overfitting is a problem worth solving because it prevents us from making effective use of the data we have. Sometimes overfitting can even be anticipated before the model is built: warning signs can be found by examining the data, how it was collected, how it was sampled, and whether the assumptions and representation behind it are sound. To avoid it, inspect the data before modeling. Often, however, overfitting is not detected during preprocessing and only shows up after the model is built; in that case, the methods above can be used to address it.

