Foreword
- This article focuses on fundamental data-analysis knowledge and skills, walking through the complete process from data exploration to modeling to model interpretation
- The content covers data exploration, model building, parameter-tuning techniques, and SHAP model explanation
- The data is the crab age prediction dataset from the Kaggle platform; the details are described below
Data Description
Data Background
Crabs are delicious, and many countries import large quantities of them for consumption every year. The main advantages of crab farming are low labor costs, relatively low production costs, and fast growth, and commercial crab farming is changing the lifestyle of people in coastal areas. With proper care and management, crab farming can earn more money than shrimp farming. Mud crabs can be kept under two systems: grow-out farming and fattening.
Data Value
For commercial crab farmers, knowing the right age of a crab helps them decide whether and when to harvest. Beyond a certain age, crabs grow negligibly in their physical characteristics, so timing the harvest well reduces cost and increases profit. The goals of this dataset are:
- Exploratory Data Analysis - See how different physical characteristics change with age.
- Feature Engineering - Define new features by combining the given data points to help improve model accuracy.
- Regression Model - Build a regression model to predict crab age.
Data Fields
- `Sex`: crab sex - male (M), female (F), and indeterminate (I)
- `Length`: length of the crab (in feet; 1 foot = 30.48 cm)
- `Diameter`: diameter of the crab (in feet; 1 foot = 30.48 cm)
- `Height`: height of the crab (in feet; 1 foot = 30.48 cm)
- `Weight`: weight of the crab (in ounces; 1 pound = 16 ounces)
- `Shucked Weight`: weight without the shell (in ounces; 1 pound = 16 ounces)
- `Viscera Weight`: weight of the abdominal organs deep in the body (in ounces; 1 pound = 16 ounces)
- `Shell Weight`: weight of the shell (in ounces; 1 pound = 16 ounces)
- `Age`: age of the crab (in months)
Dependent Packages
- `pandas`: base package for reading data
- `ydata-profiling`: package for quick data exploration (see its GitHub project and official documentation)
- `sklearn`: the classic machine-learning package, which needs no further introduction
- `shap`: SHAP (SHapley Additive exPlanations) is a game-theoretic approach to explaining the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions (see the paper for details and citations).
- The analysis below relies on the packages above; please install them in advance with `pip install`. In a Jupyter Notebook, use `!pip install` instead.
Import Necessary Packages
import numpy as np
import pandas as pd
from plotnine import *
import seaborn as sns
from scipy import stats
import matplotlib as mpl
import matplotlib.pyplot as plt
# Display Chinese characters correctly in plots
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Embed figures in the notebook
%matplotlib inline
# Increase figure resolution
%config InlineBackend.figure_format = 'retina'
from ydata_profiling import ProfileReport
import shap
# Ignore warnings
import warnings
warnings.filterwarnings('ignore')
Import Data
df = pd.read_csv('/kaggle/input/crab-age-prediction/CrabAgePrediction.csv')
df.head()
Output:
Sex Length Diameter Height Weight Shucked Weight Viscera Weight Shell Weight Age
0 F 1.4375 1.1750 0.4125 24.635715 12.332033 5.584852 6.747181 9
1 M 0.8875 0.6500 0.2125 5.400580 2.296310 1.374951 1.559222 6
2 I 1.0375 0.7750 0.2500 7.952035 3.231843 1.601747 2.764076 6
3 F 1.1750 0.8875 0.2500 13.480187 4.748541 2.282135 5.244657 10
4 I 0.8875 0.6625 0.2125 6.903103 3.458639 1.488349 1.700970 6
Data Analysis
profile = ProfileReport(df, title="Crab data report")
profile.to_notebook_iframe()
- Since CSDN cannot embed HTML files, the complete data report is split up here and explained piece by piece.
Overview
- The following information can be obtained from the figure above (verified directly with pandas in the sketch below):
- The data contains 9 features in total
- The data has 3893 samples in total
- There are no missing values and no duplicated samples
- There is 1 categorical variable; the other 8 are numerical variables
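These overview figures can be checked directly with pandas; a minimal sketch:

```python
# Verify the overview statistics reported by ydata-profiling
print(df.shape)               # (3893, 9): 3893 samples, 9 features
print(df.isna().sum().sum())  # 0: no missing values
print(df.duplicated().sum())  # 0: no duplicated samples
print(df.dtypes)              # Sex is object (categorical); the rest are numeric
```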
- The following information can be obtained from the figure above (the same correlations are computed in the sketch below):
- `Length` has a strong correlation with `Diameter`
- `Length` has a strong correlation with `Height`
- `Length` has a strong correlation with `Weight`
- `Length` has a strong correlation with `Shucked Weight`
- `Length` has a strong correlation with `Viscera Weight`
- `Length` has a strong correlation with `Shell Weight`
- `Length` has a strong correlation with `Age`
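A minimal sketch of reproducing these correlations with pandas' Pearson correlation:

```python
# Pearson correlation of Length with every other numeric feature
corr = df.corr(numeric_only=True)
print(corr['Length'].drop('Length').sort_values(ascending=False))
```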
Variables
- The following information can be obtained from the figure above:
- The feature `Sex` is a categorical variable with 3 categories
- There are 1435 samples of category `M`, 1233 of category `I`, and 1225 of category `F`
- The following information can be obtained from the figure above:
- The feature `Length` is a numerical variable with 134 distinct values (a distinct ratio of 3.4%)
- The histogram shows the distribution's peak is shifted to the right, with a longer left tail
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Diameter` is a numerical variable with 111 distinct values (a distinct ratio of 2.9%)
- The histogram shows the distribution's peak is shifted to the right, with a longer left tail
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Height` is a numerical variable with 51 distinct values (a distinct ratio of 1.3%)
- The histogram shows a roughly normal distribution, but the maximum value of 2.825 is abnormal and the sample may need to be removed
- There are 2 zero values, 0.1% of the overall sample; by experience, the height cannot be 0, so these 2 samples are probably abnormal and need to be removed
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Weight` is a numerical variable with 2343 distinct values (a distinct ratio of 60.2%)
- The histogram shows the distribution's peak is shifted to the left, with a long right tail; values above 60 should be examined as potential outliers and handled accordingly
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Shucked Weight` is a numerical variable with 1482 distinct values (a distinct ratio of 38.1%)
- The histogram shows the distribution's peak is shifted to the left, with a long right tail; values above 30 should be examined as potential outliers and handled accordingly
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Viscera Weight` is a numerical variable with 867 distinct values (a distinct ratio of 22.3%)
- The histogram shows the distribution's peak is shifted to the left, with a long right tail; values above 15 should be examined as potential outliers and handled accordingly
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Shell Weight` is a numerical variable with 907 distinct values (a distinct ratio of 23.3%)
- The histogram shows the distribution's peak is shifted to the left, with a long right tail; values above 20 should be examined as potential outliers and handled accordingly
- The report also lists the feature's statistics, such as minimum and maximum, mean, and quartiles
- The following information can be obtained from the figure above:
- The feature `Age` is a numerical variable with 28 distinct values (a distinct ratio of 0.7%)
- The histogram shows a roughly normal distribution; values above 25 should be examined as potential outliers and handled accordingly
- The report also lists each feature's statistics, such as minimum and maximum, mean, and quartiles (reproduced in the sketch below)
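The per-variable statistics quoted above (distinct counts, skew direction, the zero heights) can be reproduced directly; a minimal sketch:

```python
# Reproduce the per-variable statistics from the report
print(df.nunique())                 # distinct values per feature (e.g. Length: 134, Age: 28)
print(df.skew(numeric_only=True))   # sign and strength of skew for each numeric feature
print((df['Height'] == 0).sum())    # the 2 suspicious samples with Height == 0
print(df['Height'].max())           # the abnormal maximum of 2.825
print(df.describe())                # min/max, mean, quartiles for every numeric feature
```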
Interaction Diagram
- Any two variables can be plotted against each other here; since space is limited, only the interaction plot with `Weight` on the vertical axis and `Age` on the horizontal axis is shown.
Correlation Diagram
- The strong correlations shown in the figure above were already discussed in the overview stage and are not repeated here.
Data Processing
- Based on the analysis above, the data is processed as follows (a sketch of these steps follows the list):
- Remove the samples whose `Height` is 0
- Since the categories of `Sex` have no ordinal relationship, one-hot encode it with `df = pd.get_dummies(df, columns=["Sex"])`
- Use the isolation forest algorithm from `sklearn` to remove potential outliers, with the outlier proportion set to 0.05 (an empirical value)
- Standardize the data with a z-score transform; because the dataset is small, use 0.9 of it for training and 0.1 for testing, and take the mean MSE of 10-fold cross-validation as the evaluation metric
- Build a regression model with `Age` as the dependent variable and the remaining features as the independent variables
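The article does not show the code for these steps; a minimal sketch, assuming the proportions described above and the `random_state=2023` that appears in the parameter table later:

```python
import pandas as pd
from scipy import stats
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import train_test_split

# Remove the abnormal samples whose Height is 0
df = df[df['Height'] != 0].copy()

# One-hot encode Sex (no ordinal relationship between its categories)
df = pd.get_dummies(df, columns=["Sex"], dtype=float)

# Remove potential outliers with an isolation forest (contamination = 0.05, an empirical value)
iso = IsolationForest(contamination=0.05, random_state=2023)
df = df[iso.fit_predict(df) == 1].copy()   # fit_predict returns 1 for inliers, -1 for outliers

# Separate the target and z-score standardize the features
y = df.pop('Age')
X = df.apply(stats.zscore)

# Small dataset: 90% for training, 10% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=2023)
```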
Build Models and Tune Parameters
- Build a variety of regression models with sklearn-style APIs, such as `gbr`, `catboost`, `lightgbm`, etc. (a sketch of such a comparison follows the results table):
ID | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
---|---|---|---|---|---|---|---|---|
gbr | Gradient Boosting Regressor | 1.5019 | 4.4889 | 2.1169 | 0.5530 | 0.1727 | 0.1501 | 0.3410 |
catboost | CatBoost Regressor | 1.5082 | 4.5334 | 2.1277 | 0.5485 | 0.1729 | 0.1505 | 2.6590 |
lightgbm | Light Gradient Boosting Machine | 1.5240 | 4.6280 | 2.1504 | 0.5389 | 0.1751 | 0.1516 | 0.4320 |
rf | Random Forest Regressor | 1.5338 | 4.6551 | 2.1561 | 0.5363 | 0.1762 | 0.1535 | 0.8190 |
et | Extra Trees Regressor | 1.5494 | 4.7462 | 2.1772 | 0.5267 | 0.1780 | 0.1552 | 0.4890 |
ridge | Ridge Regression | 1.5771 | 4.8477 | 2.1979 | 0.5166 | 0.1824 | 0.1596 | 0.0430 |
lr | Linear Regression | 1.5772 | 4.8479 | 2.1980 | 0.5165 | 0.1822 | 0.1596 | 0.4730 |
lar | Least Angle Regression | 1.5772 | 4.8479 | 2.1980 | 0.5165 | 0.1822 | 0.1596 | 0.0470 |
br | Bayesian Ridge | 1.5771 | 4.8482 | 2.1980 | 0.5166 | 0.1824 | 0.1596 | 0.0440 |
huber | Huber Regressor | 1.5435 | 4.9130 | 2.2136 | 0.5105 | 0.1814 | 0.1503 | 0.0620 |
xgboost | Extreme Gradient Boosting | 1.5884 | 5.0390 | 2.2429 | 0.4972 | 0.1822 | 0.1581 | 0.2910 |
knn | K Neighbors Regressor | 1.6147 | 5.1607 | 2.2705 | 0.4856 | 0.1853 | 0.1599 | 0.0500 |
omp | Orthogonal Matching Pursuit | 1.8157 | 6.1171 | 2.4715 | 0.3917 | 0.2084 | 0.1867 | 0.0400 |
en | Elastic Net | 1.8855 | 6.6865 | 2.5833 | 0.3367 | 0.2212 | 0.2007 | 0.0440 |
lasso | Lasso Regression | 1.9536 | 7.1238 | 2.6663 | 0.2937 | 0.2360 | 0.2154 | 0.0440 |
llar | Lasso Least Angle Regression | 1.9536 | 7.1238 | 2.6663 | 0.2937 | 0.2360 | 0.2154 | 0.0420 |
ada | AdaBoost Regressor | 2.2463 | 7.2386 | 2.6873 | 0.2767 | 0.2325 | 0.2479 | 0.1820 |
dt | Decision Tree Regressor | 2.0626 | 8.9308 | 2.9865 | 0.1079 | 0.2389 | 0.2035 | 0.0530 |
par | Passive Aggressive Regressor | 2.2911 | 9.0001 | 2.9784 | 0.0897 | 0.2636 | 0.2401 | 0.0480 |
dummy | Dummy Regressor | 2.3369 | 10.0990 | 3.1743 | -0.0006 | 0.2871 | 0.2672 | 0.0990 |
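The comparison code itself is not shown in the article; a minimal sklearn-only sketch of how a few rows of such a table could be produced (the tooling behind the full table is an assumption):

```python
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import Ridge

models = {
    'gbr':   GradientBoostingRegressor(random_state=2023),
    'rf':    RandomForestRegressor(random_state=2023),
    'ridge': Ridge(),
}
for name, model in models.items():
    # Mean MSE over 10-fold cross-validation, matching the evaluation metric above
    mse = -cross_val_score(model, X_train, y_train,
                           scoring='neg_mean_squared_error', cv=10).mean()
    print(f'{name}: mean CV MSE = {mse:.4f}')
```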
- The `gbr` model has the lowest `MSE` and performs best. Its parameters are tuned with a random search over 20 iterations (sketched after the parameter table below); the 10-fold cross-validation results of the best model are as follows.
Fold | MAE | MSE | RMSE | R2 | RMSLE | MAPE |
---|---|---|---|---|---|---|
0 | 1.5267 | 4.7045 | 2.1690 | 0.5536 | 0.1719 | 0.1519 |
1 | 1.4646 | 4.3281 | 2.0804 | 0.5695 | 0.1674 | 0.1447 |
2 | 1.5053 | 4.3746 | 2.0915 | 0.5534 | 0.1757 | 0.1562 |
3 | 1.4995 | 4.5526 | 2.1337 | 0.5466 | 0.1731 | 0.1512 |
4 | 1.5798 | 4.6405 | 2.1542 | 0.5627 | 0.1842 | 0.1652 |
5 | 1.3879 | 3.7092 | 1.9259 | 0.6079 | 0.1654 | 0.1462 |
6 | 1.5613 | 4.8066 | 2.1924 | 0.5326 | 0.1744 | 0.1508 |
7 | 1.4739 | 4.4213 | 2.1027 | 0.5899 | 0.1683 | 0.1440 |
8 | 1.4873 | 4.5554 | 2.1343 | 0.6052 | 0.1716 | 0.1502 |
9 | 1.4654 | 4.0750 | 2.0187 | 0.4792 | 0.1658 | 0.1471 |
Mean | 1.4952 | 4.4168 | 2.1003 | 0.5601 | 0.1718 | 0.1508 |
Std | 0.0515 | 0.3082 | 0.0748 | 0.0359 | 0.0053 | 0.0060 |
- The optimal parameters are as follows:
Parameter | Value |
---|---|
alpha | 0.9 |
ccp_alpha | 0.0 |
criterion | friedman_mse |
init | None |
learning_rate | 0.1 |
loss | squared_error |
max_depth | 3 |
max_features | None |
max_leaf_nodes | None |
min_impurity_decrease | 0.0 |
min_samples_leaf | 1 |
min_samples_split | 2 |
min_weight_fraction_leaf | 0.0 |
n_estimators | 100 |
n_iter_no_change | None |
random_state | 2023 |
subsample | 1.0 |
tol | 0.0001 |
validation_fraction | 0.1 |
verbose | 0 |
warm_start | False |
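The search itself is not shown; a minimal sketch using sklearn's `RandomizedSearchCV` with the 20 iterations and 10-fold cross-validation described above (the parameter space here is an illustrative assumption):

```python
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import GradientBoostingRegressor

# Illustrative search space; the one actually used by the author is not given
param_dist = {
    'n_estimators':  [100, 200, 300, 500],
    'learning_rate': [0.01, 0.05, 0.1, 0.2],
    'max_depth':     [2, 3, 4, 5],
    'subsample':     [0.8, 0.9, 1.0],
}
search = RandomizedSearchCV(
    GradientBoostingRegressor(random_state=2023),
    param_distributions=param_dist,
    n_iter=20,                            # 20 random-search iterations
    cv=10,                                # 10-fold cross-validation
    scoring='neg_mean_squared_error',
    random_state=2023,
)
search.fit(X_train, y_train)
model = search.best_estimator_            # the tuned model used in the analysis below
print(search.best_params_)
```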
Model Analysis
- Analyze the model and visualize several diagnostic indicators
Model Residual Plot
- The model's $R^2$ on the training set and on the test set differ little, indicating no overfitting
- The residuals are evenly distributed on both sides of the 0 line and appear random
- The residual distributions on the training set and the test set are basically the same (a sketch of such a plot follows)
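A minimal sketch of such a residual plot with matplotlib (the helper that produced the article's figure is not shown):

```python
import matplotlib.pyplot as plt

# Residuals (actual - predicted) against predictions for both splits
for X_part, y_part, label in [(X_train, y_train, 'train'), (X_test, y_test, 'test')]:
    pred = model.predict(X_part)
    plt.scatter(pred, y_part - pred, s=8, alpha=0.5, label=label)
plt.axhline(0, color='black', lw=1)   # residuals should scatter randomly around this line
plt.xlabel('Predicted Age')
plt.ylabel('Residual')
plt.legend()
plt.show()
```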
Model Learning Curve
- The learning curve shows whether the model overfits: in the figure above, the training and validation curves converge toward a common middle value, indicating that the model is not overfitting (a sketch of plotting such a curve follows).
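A minimal sketch of plotting such a learning curve with sklearn:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve

# R^2 on the training folds and validation folds at increasing training sizes
sizes, train_scores, val_scores = learning_curve(
    model, X_train, y_train, cv=10, scoring='r2',
    train_sizes=np.linspace(0.1, 1.0, 5))
plt.plot(sizes, train_scores.mean(axis=1), 'o-', label='training score')
plt.plot(sizes, val_scores.mean(axis=1), 'o-', label='cross-validation score')
plt.xlabel('Number of training samples')
plt.ylabel('$R^2$')
plt.legend()
plt.show()
```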
Model Interpretation
- This module uses the SHAP package, applying Shapley values to evaluate the impact of each feature on the model.
Force Plot
- We can take a single sample and visualize its prediction process; the code is as follows
# Fetch the training data (s is assumed to be the experiment/session object from the model-comparison tooling above)
X = s.get_config('X_train')
# Build a tree explainer for the tuned gradient boosting model
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# Force plot of the first sample's prediction
shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:])
- Red indicates features pushing the prediction higher (a positive force); blue indicates features pushing the prediction lower (a negative force)
Dependence Plots
- To understand how a single feature affects the model's output and how features interact, we can draw dependence plots
- Dependence plots follow these rules:
- Each dot is a single sample from the dataset
- The x-axis is the value of the feature (taken from the X matrix, stored in `shap_values.data`)
- The y-axis is that feature's SHAP value (stored in `shap_values.values`), i.e. how much knowing the feature's value changes the model's output for that sample's prediction; for this regression model, the units are those of the predicted `Age`
- The color of each dot is determined by a second feature; if no feature is passed in, the function picks the column with the strongest interaction with the analyzed feature (for example, below, the feature interacting most strongly with `Length` is `Sex_F`)
- We plot every column except the one-hot-encoded `Sex` columns, and let the function automatically select the feature with the strongest interaction for coloring
# Draw a dependence plot for every feature except the one-hot Sex columns
for name in X.columns:
    if 'Sex' not in name:
        shap.dependence_plot(name, shap_values, X)
- From the figure above:
- The feature interacting most strongly with `Length` is `Sex_F`
- Female crabs longer than 1.25 are older than male and indeterminate-sex crabs
- Below a length of 1.00, age increases with length; above 1.00, greater length does not necessarily mean greater age
- From the figure above:
- The feature interacting most strongly with `Diameter` is `Sex_M`
- When the diameter is below 0.6, the larger the diameter, the younger the crab; between 0.6 and 0.8, the larger the diameter, the older the crab; above 0.8, diameter has little effect on age
- Below a diameter of 0.7, male crabs are generally older than female and indeterminate-sex crabs
- From the figure above:
- The feature interacting most strongly with `Height` is `Length`
- When the height exceeds 0.3, longer crabs are older
- From the figure above:
- The feature interacting most strongly with `Weight` is `Sex_M`
- At the same weight, male crabs are older than female and indeterminate-sex crabs
- From the figure above:
- The feature interacting most strongly with `Shucked Weight` is `Diameter`
- When the shucked weight is below 3, the smaller the weight, the younger the crab
- The larger the shucked weight, the larger the crab's diameter
Beeswarm Summary Plot
- The beeswarm plot is designed to give an information-dense summary of how the top features in a dataset affect the model's output (the call producing it is sketched below)
- Each instance of the given explanation is represented by a single dot on each feature row
- The x-position of each dot is determined by that feature's SHAP value, and dots pile up along each feature row to show density; color shows the original value of the feature
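The figure can be produced from the SHAP values computed earlier; a minimal sketch using the same legacy API as the plots above:

```python
# Beeswarm summary plot: one row per feature, one dot per sample
shap.summary_plot(shap_values, X)
```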
- The following conclusions can be drawn from the figure above:
- Shell weight (`Shell Weight`) is on average the most important feature; the greater the shell weight, the older the crab
- The lighter the shucked weight (`Shucked Weight`), the older the crab tends to be; crabs with a heavy shucked weight are almost always younger
- The longer the length (`Length`), the younger the crab
- The smaller the diameter (`Diameter`), the younger the crab tends to be
- Males (`Sex_M`) and females (`Sex_F`) are older than crabs of indeterminate sex (`Sex_I`)