Python: drawing a GAM nonlinear regression scatter plot and fitted curve

Author: CSDN @ _Yakult_

This article shows how to use Python to draw a scatter plot and fitted curve for nonlinear regression with a generalized additive model (GAM). It also records the code for calculating RMSE, ubRMSE, R2 and Bias.

1. Detailed explanation of GAM nonlinear regression

GAM (Generalized Additive Model) is a flexible statistical model for nonlinear regression and classification. It is an extension of generalized linear models (GLMs) that can model various types of nonlinear relationships.

In a GAM, the expected value of the target variable, after transformation by a link function, is modeled as a sum of smooth functions of the independent variables. A GLM instead uses a linear combination of the variables themselves. The smooth functions capture nonlinear relationships between the independent and dependent variables. Because each independent variable gets its own smooth function (or several), GAMs can flexibly model many kinds of nonlinear relationships.

The general form of GAM can be expressed as:

$$
y = f_1(x_1) + f_2(x_2) + \dots + f_p(x_p) + \varepsilon
$$

where y is the target variable, x1, x2, …, xp are the independent variables, f1, f2, …, fp are the smoothing functions, and ε is the error term.
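
The equation above corresponds to a continuous target with an identity link. The more general notation adds an intercept $\beta_0$ and a link function $g$; both are standard additions not shown in the formula above. In that notation, a GAM relates the expected value of $y$ to the sum of smooth terms:

$$
g\big(\mathbb{E}[y]\big) = \beta_0 + f_1(x_1) + f_2(x_2) + \dots + f_p(x_p)
$$

With $g$ the identity and Gaussian errors, this reduces to the additive form above, with $\beta_0$ absorbed into the $f_j$. With $g$ the logit, it gives a logistic GAM for binary targets. Replacing each $f_j(x_j)$ by a linear term $\beta_j x_j$ recovers an ordinary GLM, which is why GAMs are described as an extension of GLMs.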

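Commonly used smooth functions in GAMs include regression splines (for example cubic B-splines), natural cubic splines, penalized splines and local regression (LOESS) smoothers. A roughness penalty, or a limited number of basis functions, keeps each fitted curve smooth. The curve then follows the underlying nonlinear trend instead of the random noise.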
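
The following is a minimal NumPy sketch of one common penalized smoother, a P-spline: a cubic B-spline basis with a second-order difference penalty, after Eilers and Marx. It fits a single smooth f(x) by penalized least squares and picks the smoothing parameter λ by generalized cross-validation (GCV). The helper names `bspline_basis` and `fit_pspline`, the 20 segments and the λ grid are illustrative assumptions. The sketch is conceptually similar to the spline terms that pyGAM's `s()` builds, but it is not pyGAM's implementation.

```python
import numpy as np


def bspline_basis(x, xl, xr, n_seg=20, degree=3):
    """Evaluate B-splines with equally spaced knots on [xl, xr] (Cox-de Boor recursion)."""
    x = np.asarray(x, dtype=float)
    dx = (xr - xl) / n_seg
    knots = xl + dx * np.arange(-degree, n_seg + degree + 1)
    # Degree 0: indicator of each knot interval [t_i, t_{i+1})
    B = ((x[:, None] >= knots[None, :-1]) & (x[:, None] < knots[None, 1:])).astype(float)
    for d in range(1, degree + 1):
        left = (x[:, None] - knots[None, :-(d + 1)]) / (knots[None, d:-1] - knots[None, :-(d + 1)])
        right = (knots[None, d + 1:] - x[:, None]) / (knots[None, d + 1:] - knots[None, 1:-d])
        B = left * B[:, :-1] + right * B[:, 1:]
    return B  # shape (len(x), n_seg + degree)


def fit_pspline(x, y, n_seg=20, lambdas=np.logspace(-4, 4, 41)):
    """Penalized least squares with a 2nd-order difference penalty; lambda chosen by GCV."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xl, xr = x.min(), x.max()
    B = bspline_basis(x, xl, xr, n_seg)
    D = np.diff(np.eye(B.shape[1]), n=2, axis=0)  # second-order difference matrix
    P = D.T @ D                                   # roughness penalty matrix
    BtB, Bty = B.T @ B, B.T @ y
    n = y.size

    best = None
    for lam in lambdas:
        A = BtB + lam * P
        coef = np.linalg.solve(A, Bty)
        edf = np.trace(np.linalg.solve(A, BtB))   # effective degrees of freedom = trace of hat matrix
        rss = np.sum((y - B @ coef) ** 2)
        gcv = n * rss / (n - edf) ** 2
        if best is None or gcv < best[0]:
            best = (gcv, lam, coef, edf)

    _, lam, coef, edf = best

    def predict(x_new):
        return bspline_basis(x_new, xl, xr, n_seg) @ coef

    return predict, lam, edf


rng = np.random.default_rng(42)
x = np.linspace(0, 10, 100)
y = np.sin(x) + rng.normal(0, 0.1, x.size)
predict, lam, edf = fit_pspline(x, y)
rmse_vs_truth = np.sqrt(np.mean((predict(x) - np.sin(x)) ** 2))
print(f"lambda = {lam:.3g}, edf = {edf:.1f}, RMSE against sin(x) = {rmse_vs_truth:.3f}")
```

A large λ pushes the curve toward a straight line, the null space of the second-difference penalty. A small λ lets it use all 23 basis functions. GCV trades these off automatically.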

The modeling process of GAM usually involves the following steps:

  1. Data preparation: collect the independent variables and the target variable, and perform any necessary preprocessing.

  2. Choice of smooth functions: based on the characteristics of each independent variable and the expected shape of its relationship with the target, choose an appropriate smoother (for example a regression spline or a natural spline) and its number of basis functions.

  3. Fitting the GAM: estimate all the smooth functions jointly. In practice this is done by penalized (iteratively reweighted) least squares or by backfitting. The amount of smoothing is chosen by a criterion such as generalized cross-validation (GCV) or REML.

  4. Model evaluation: assess the fitted model, including its goodness of fit and residual diagnostics.

  5. Prediction and inference: use the fitted GAM to make predictions and to draw inferences about the relationships between the variables.
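
Step 3 is classically done with Hastie and Tibshirani's backfitting algorithm. The algorithm cycles through the predictors and smooths each partial residual in turn. The sketch below rests on these assumptions:

- There are two predictors.
- A simple Gaussian kernel (Nadaraya-Watson) smoother is used.
- The bandwidths are picked by hand.

The helper names `kernel_smooth` and `backfit` are hypothetical. pyGAM, for instance, uses penalized iteratively reweighted least squares rather than this loop.

```python
import numpy as np


def kernel_smooth(x, r, bandwidth):
    """Nadaraya-Watson smoother: Gaussian-weighted local mean of r at each point of x."""
    w = np.exp(-0.5 * ((x[:, None] - x[None, :]) / bandwidth) ** 2)
    return (w @ r) / w.sum(axis=1)


def backfit(X, y, bandwidths, max_iter=100, tol=1e-6):
    """Fit y ~ alpha + f_1(X[:, 0]) + ... + f_p(X[:, p-1]) by backfitting."""
    n, p = X.shape
    alpha = y.mean()
    f = np.zeros((n, p))  # column j holds f_j evaluated at the data points
    for _ in range(max_iter):
        f_old = f.copy()
        for j in range(p):
            # Partial residual: remove the intercept and every other component
            partial = y - alpha - f.sum(axis=1) + f[:, j]
            f[:, j] = kernel_smooth(X[:, j], partial, bandwidths[j])
            f[:, j] -= f[:, j].mean()  # center f_j so the intercept stays identifiable
        if np.max(np.abs(f - f_old)) < tol:
            break
    return alpha, f


rng = np.random.default_rng(0)
n = 300
X = np.column_stack([rng.uniform(0, 10, n), rng.uniform(-2, 2, n)])
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2 + rng.normal(0, 0.2, n)

alpha, f = backfit(X, y, bandwidths=[0.5, 0.3])
y_hat = alpha + f.sum(axis=1)
r2 = 1 - np.sum((y - y_hat) ** 2) / np.sum((y - y.mean()) ** 2)
print(f"R^2 = {r2:.3f}")
print("corr(f_1, sin x1)   =", np.corrcoef(f[:, 0], np.sin(X[:, 0]))[0, 1])
print("corr(f_2, 0.5 x2^2) =", np.corrcoef(f[:, 1], 0.5 * X[:, 1] ** 2)[0, 1])
```

Each column of `f` is one estimated component, so it can be plotted against its own predictor. This is the interpretability property listed below.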

GAMs have many advantages, including:

  • Flexibility: GAM can flexibly model various nonlinear relationships and is suitable for various complex data patterns.

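  • Interpretability: because the model is a sum of separate smooth functions, each fitted function can be plotted on its own. The plot shows how that independent variable relates to the target variable with the other variables held fixed.

  • Resistance to noise: the smoothing penalty keeps the fitted curves from chasing random fluctuations in the data. A GAM fitted by least squares is, however, not robust to gross outliers, which should still be checked.

  • Shrinkage of irrelevant terms: the penalty shrinks each smooth function toward a simpler (linear) shape. With an additional shrinkage penalty, a term can be shrunk to zero, which acts as a form of automatic variable selection.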

However, GAM also has some limitations and caveats:

  • Choice of smooth functions: choosing the type of smoother, the number of basis functions and the amount of smoothing is a key step. It should be guided by the characteristics of the data and the research question.

  • Multiple comparisons: when many smooth terms are tested for significance at once, the chance of a false positive grows. P-values should therefore be interpreted cautiously or adjusted for multiple comparisons.

  • Computational complexity: GAMs are more expensive to fit than linear models, especially on large datasets. Each smooth term adds many basis coefficients, and the smoothing parameters must also be estimated.

Overall, the GAM is a powerful nonlinear modeling tool that helps reveal nonlinear relationships in data. With a sensible choice of smooth functions and careful model evaluation, GAMs can be applied to practical problems for both prediction and inference.

2. The code

import numpy as np
import matplotlib.pyplot as plt
from pygam import LinearGAM, s


# Generate simulated data: a noisy sine curve
np.random.seed(42)
n = 100
X = np.linspace(0, 10, n)
y = np.sin(X) + np.random.normal(0, 0.1, n)


# Fit the GAM with one smooth (spline) term on feature 0;
# pyGAM expects a 2-D feature matrix, so reshape X to shape (n, 1)
gam = LinearGAM(s(0)).fit(X.reshape(-1, 1), y)


# Draw the scatter plot with the fitted curve
fig, axs = plt.subplots(1, 1, figsize=(8, 6))
XX = gam.generate_X_grid(term=0, n=100)           # evenly spaced grid, shape (100, 1)
y_fit = gam.predict(XX)                           # fitted curve, intercept included
confi = gam.confidence_intervals(XX, width=0.95)  # 95% confidence band, shape (100, 2)
axs.plot(XX[:, 0], y_fit, color='blue', label='GAM fit')
axs.fill_between(XX[:, 0], confi[:, 0], confi[:, 1], color='blue', alpha=0.3)
axs.scatter(X, y, color='black', alpha=0.5, label='Data')
axs.set_xlabel('X', fontsize=12)
axs.set_ylabel('y', fontsize=12)
axs.set_title('Correlation Plot - GAM', fontsize=14)
axs.legend()

plt.tight_layout()
plt.show()

3. Calculate RMSE, ubRMSE, R2, Bias
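
For reference, the standard definitions of the four metrics are given below. Here $y_i$ are the observations, $\hat{y}_i$ the model predictions, bars denote means over the $n$ points, and bias is taken as prediction minus observation.

$$
\mathrm{Bias} = \frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i), \qquad
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)^2}
$$

$$
\mathrm{ubRMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\Big[(\hat{y}_i - \bar{\hat{y}}) - (y_i - \bar{y})\Big]^2} = \sqrt{\mathrm{RMSE}^2 - \mathrm{Bias}^2}
$$

$$
R^2 = 1 - \frac{\sum_{i=1}^{n}(y_i - \hat{y}_i)^2}{\sum_{i=1}^{n}(y_i - \bar{y})^2}
$$

The second form of ubRMSE shows that it is the part of the RMSE left over once the mean bias is removed.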

import numpy as np
import matplotlib.pyplot as plt
from pygam import LinearGAM, s
from sklearn.metrics import mean_squared_error, r2_score


# Generate simulated data: a noisy sine curve
np.random.seed(42)
n = 100
X = np.linspace(0, 10, n)
y = np.sin(X) + np.random.normal(0, 0.1, n)

# Fit the GAM (pyGAM expects a 2-D feature matrix)
gam = LinearGAM(s(0)).fit(X.reshape(-1, 1), y)

# Evaluate on the observed points: compare y with the GAM predictions
# (not with the grid of X values)
y_true = y
y_pred = gam.predict(X.reshape(-1, 1))

# Compute RMSE
rmse = np.sqrt(mean_squared_error(y_true, y_pred))

# Compute R2
r2 = r2_score(y_true, y_pred)

# Compute Bias (mean of prediction minus observation)
bias = np.mean(y_pred - y_true)

# Compute ubRMSE (unbiased RMSE): the RMSE after removing the mean bias,
# equivalent to sqrt(rmse**2 - bias**2)
ubrmse = np.sqrt(np.mean(((y_pred - y_pred.mean()) - (y_true - y_true.mean())) ** 2))

# Draw the scatter plot with the fitted curve and its 95% confidence band
fig, axs = plt.subplots(1, 1, figsize=(8, 6))
XX = gam.generate_X_grid(term=0, n=100)
y_fit = gam.predict(XX)
confi = gam.confidence_intervals(XX, width=0.95)
axs.plot(XX[:, 0], y_fit, color='blue', label='GAM fit')
axs.fill_between(XX[:, 0], confi[:, 0], confi[:, 1], color='blue', alpha=0.3)
axs.scatter(X, y, color='black', alpha=0.5, label='Data')
axs.set_xlabel('X', fontsize=12)
axs.set_ylabel('y', fontsize=12)
axs.set_title('Correlation Plot - GAM', fontsize=14)
axs.legend(loc='lower left')

# Write the metrics onto the figure
textstr = f'RMSE = {rmse:.4f}\nR2 = {r2:.4f}\nBias = {bias:.4f}\nubRMSE = {ubrmse:.4f}'
props = dict(boxstyle='round', facecolor='white', alpha=0.5)
# verticalalignment accepts 'top', 'bottom', 'center', 'baseline', 'center_baseline';
# 'top' anchors the box's upper edge at y = 0.95 so it stays inside the axes
axs.text(0.05, 0.95, textstr, transform=axs.transAxes, fontsize=12,
         verticalalignment='top', bbox=props)

plt.tight_layout()
plt.show()
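
To check the formulas independently of pyGAM, here is a small NumPy-only helper. The name `evaluation_metrics` is hypothetical. The example also illustrates why ubRMSE is reported alongside RMSE: adding a constant offset to the predictions shifts Bias and inflates RMSE, but leaves ubRMSE unchanged.

```python
import numpy as np


def evaluation_metrics(y_true, y_pred):
    """Return RMSE, Bias, ubRMSE and R2 (bias = mean of prediction minus observation)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    rmse = np.sqrt(np.mean(err ** 2))
    bias = np.mean(err)
    ubrmse = np.sqrt(np.mean((err - bias) ** 2))  # equals sqrt(RMSE^2 - Bias^2)
    r2 = 1.0 - np.sum(err ** 2) / np.sum((y_true - y_true.mean()) ** 2)
    return {"RMSE": rmse, "Bias": bias, "ubRMSE": ubrmse, "R2": r2}


rng = np.random.default_rng(0)
obs = np.sin(np.linspace(0, 10, 100))
pred = obs + rng.normal(0, 0.1, obs.size)
m = evaluation_metrics(obs, pred)
m_shifted = evaluation_metrics(obs, pred + 0.5)  # same errors plus a constant offset
for name in m:
    print(f"{name:>6}: {m[name]:.4f} -> {m_shifted[name]:.4f}")
```

The R2 line uses the same formula as `sklearn.metrics.r2_score`, so either can be used in the plotting script.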

Disclaimer:
As an author, I attach great importance to my own works and intellectual property rights. I hereby declare that all my original articles are protected by copyright law, and no one may publish them publicly without my authorization.
Some of my articles are published as paid content on well-known platforms. I hope readers will respect intellectual property rights and refrain from infringement. Publishing any of my paid articles on the Internet without my authorization, whether free or paid (including commercially), will be regarded as a violation of my copyright, and I reserve the right to pursue legal responsibility.
Thank you readers for your attention and support to my article!

Origin blog.csdn.net/qq_35591253/article/details/130938485