1. Experiment introduction
This experiment implements the Notears Linear algorithm for estimating causal relationships in linear structural equation models.
The following overview was generated by ChatGPT:
The Notears Linear algorithm is an efficient method for estimating causal relationships in linear structural equation models. It estimates a weighted adjacency matrix W by minimizing a least-squares loss with an L1 penalty, subject to a constraint that forces the graph encoded by W to be acyclic; the nonzero entries of W then describe the causal relationships between the variables. The algorithm has the following advantages:
Efficiency: the combinatorial search over DAG structures is replaced by a continuous optimization problem that standard gradient-based solvers (such as those in scipy.optimize) can handle. The cost grows with the number of variables d, since every evaluation of the constraint computes a d × d matrix exponential (roughly O(d³)), and with the number of samples n through the loss.
Linear model applicability: Notears Linear is designed for linear structural equation models and handles linear causal relationships well. It assumes linearity, so nonlinear relationships are generally not captured.
Acyclicity constraint: Notears Linear adds a constraint term h(W) that equals zero exactly when the graph encoded by W has no directed cycles. The estimate is therefore a directed acyclic graph (DAG), which is what a causal interpretation requires, and this makes the results easier to interpret and more reliable.
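For reference, the optimization problem behind these points can be written as below. This is a sketch of the standard NOTEARS formulation with a least-squares loss, using the symbols of the code later in this article (X is the centered n × d data matrix, W the d × d weight matrix, λ₁ corresponds to lambda1, and ∘ is the elementwise product):

```latex
\min_{W \in \mathbb{R}^{d \times d}} \; \frac{1}{2n}\lVert X - XW \rVert_F^2 + \lambda_1 \lVert W \rVert_1
\quad \text{subject to} \quad
h(W) = \operatorname{tr}\!\left(e^{W \circ W}\right) - d = 0
```

The constraint h(W) = 0 is the smooth acyclicity condition: it holds exactly when W encodes a DAG.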
2. Experimental environment
This series of experiments uses the PyTorch deep learning framework. The environment is set up as follows (it is based on the environment of the deep learning series of articles); the code in this article itself only needs NumPy, SciPy, NetworkX and Matplotlib.
1. Configure the virtual environment
Environment of the deep learning series of articles:
conda create -n DL python=3.7
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
conda install scikit-learn
Newly added packages:
conda install pandas
conda install seaborn
conda install networkx
conda install statsmodels
pip install pyHSICLasso
Note: I installed the libraries in the order shown above. Installing them all in a single command may also work, but I have not tested it.
2. Library version introduction
Package | Version used in this experiment | Latest version at time of writing |
matplotlib | 3.5.3 | 3.8.0 |
numpy | 1.21.6 | 1.26.0 |
python | 3.7.16 | |
scikit-learn | 0.22.1 | 1.3.0 |
torch | 1.8.1+cu102 | 2.0.1 |
torchaudio | 0.8.1 | 2.0.2 |
torchvision | 0.9.1+cu102 | 0.15.2 |
Newly added packages:
networkx | 2.6.3 | 3.1 |
pandas | 1.2.3 | 2.1.1 |
pyHSICLasso | 1.4.2 | 1.4.2 |
seaborn | 0.12.2 | 0.13.0 |
statsmodels | 0.13.5 | 0.14.0 |
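To compare your own environment with the table, a small script like the one below can print the installed versions. It is a sketch: the import names (for example sklearn for scikit-learn) are the packages' usual module names, and a package that does not define __version__ is reported as "unknown".

```python
import importlib

# Import names of the packages in the table (scikit-learn is imported as sklearn)
PACKAGES = ["matplotlib", "numpy", "sklearn", "torch", "torchaudio", "torchvision",
            "networkx", "pandas", "pyHSICLasso", "seaborn", "statsmodels"]


def installed_versions(names):
    """Return {import name: version string, or None if the import fails}."""
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:  # not installed, or the install is broken
            versions[name] = None
            continue
        # Most packages expose __version__; report "unknown" if one does not
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    import sys
    print("python", sys.version.split()[0])
    for name, version in installed_versions(PACKAGES).items():
        print(name, version)
```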
3. IDE
PyCharm is recommended (the pyHSICLasso library raises an error in VS Code, and I have not yet found a fix).
3. Experimental content
0. Import necessary tools
import numpy as np
import scipy.linalg as slin
import scipy.optimize as sopt
import random
import networkx as nx
import matplotlib.pyplot as plt
1. set_random_seed
def set_random_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
Sets the seeds of Python's random module and of NumPy's global random generator so that results are reproducible.
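A quick way to see what the seed does: reseeding with the same value reproduces the same draws from both generators. Note that notears_linear itself starts from all-zero weights and uses a deterministic solver, so in this script the seed mainly matters when data are generated randomly.

```python
import random
import numpy as np


def set_random_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)


# Seeding twice with the same value reproduces the same random draws
set_random_seed(1)
first = (random.random(), np.random.rand(3))
set_random_seed(1)
second = (random.random(), np.random.rand(3))
print(first[0] == second[0], np.array_equal(first[1], second[1]))  # True True
```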
2. notears_linear
def notears_linear(X, lambda1=0.08, max_iter=100, h_tol=1e-8, rho_max=1e+16, w_threshold=0.3):
a. Input parameters
- X: the input data matrix of shape (n, d), where n is the number of samples and d is the number of variables (features).
- lambda1: weight of the L1 regularization term; default 0.08.
- max_iter: maximum number of outer iterations; default 100.
- h_tol: tolerance on the acyclicity constraint h(W); iteration stops once h ≤ h_tol; default 1e-8.
- rho_max: maximum value of the penalty parameter rho; default 1e+16.
- w_threshold: weights whose absolute value is below this threshold are set to zero to sparsify the result; default 0.3.
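The file Notears_X.csv is not included in this article. To try the function without it, one option is to simulate data from a known linear SEM X = XW + Z. The helper simulate_linear_sem below is hypothetical (not part of the original code); it assumes Gaussian noise and a nonnegative DAG weight matrix, which matches the nonnegativity bounds used later in notears_linear.

```python
import numpy as np


def simulate_linear_sem(W, n, noise_scale=1.0, seed=0):
    """Hypothetical helper: draw n samples from the linear SEM X = X @ W + Z.

    W[i, j] != 0 means an edge i -> j. If W encodes a DAG, I - W is invertible.
    """
    d = W.shape[0]
    rng = np.random.default_rng(seed)
    Z = noise_scale * rng.standard_normal((n, d))  # independent Gaussian noise
    # Row-vector form: X = X W + Z  =>  X (I - W) = Z  =>  X = Z (I - W)^{-1}
    X = Z @ np.linalg.inv(np.eye(d) - W)
    return X, Z


# Example: a chain 0 -> 1 -> 2 with nonnegative weights
W_true = np.array([[0.0, 1.5, 0.0],
                   [0.0, 0.0, 2.0],
                   [0.0, 0.0, 0.0]])
X, Z = simulate_linear_sem(W_true, n=1000)
print(X.shape)  # (1000, 3)
```

np.savetxt('Notears_X.csv', X, delimiter=',') would then write X in the comma-separated format that the main program loads with np.loadtxt.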
Several auxiliary functions are defined inside notears_linear:
b. Internal function _adj
def _adj(w):
    return w.reshape([d, d])
Converts the flattened weight vector w (length d * d) into the d × d weight matrix W. Entry W[i, j] is the weight of the edge from variable i to variable j.
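A quick check of the index mapping: reshape uses row-major (C) order by default, so parameter number i * d + j becomes W[i, j]. The same order is used when _func flattens the gradient and when the bounds list is built (i outer, j inner), so parameters, gradient entries and bounds line up.

```python
import numpy as np

d = 3
w = np.arange(d * d, dtype=float)  # a flattened parameter vector, as the optimizer sees it
W = w.reshape([d, d])              # what _adj(w) returns

# Row-major order: w[i * d + j] becomes W[i, j]
print(W)
# [[0. 1. 2.]
#  [3. 4. 5.]
#  [6. 7. 8.]]
print(W[1, 2] == w[1 * d + 2])  # True

# _func flattens the gradient with reshape(-1) in the same order
print(np.array_equal(W.reshape(-1), w))  # True
```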
c. Internal function _loss
def _loss(W):
    X_ = X @ W
    R = X - X_  # residuals of the linear SEM fit
    # W.sum() equals the L1 norm of W because the bounds keep W >= 0
    loss = 0.5 / X.shape[0] * (R ** 2).sum() + lambda1 * W.sum()  # Eq. 2
    G_loss = - 1.0 / X.shape[0] * X.T @ R + lambda1
    return loss, G_loss
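- Computes the loss and its gradient.
- The loss has two parts: the least-squares data-fitting term 1/(2n)·||X - XW||²_F and the L1 regularization term lambda1·||W||_1. Because the optimization bounds keep every entry of W nonnegative, lambda1 * W.sum() equals lambda1·||W||_1 and its gradient is the constant lambda1. (As a consequence, this version can only estimate nonnegative edge weights.)
- G_loss is the gradient of the loss with respect to the weight matrix W.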
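Because scipy.optimize.minimize is called with jac=True, it trusts the gradient returned by the objective. A finite-difference check is a cheap way to confirm that G_loss matches the loss. The sketch below rewrites _loss as a standalone function (loss_and_grad is a name used only for this illustration) and compares the analytic gradient with central differences.

```python
import numpy as np


def loss_and_grad(W, X, lambda1):
    """Same formulas as _loss, with X and lambda1 passed explicitly."""
    n = X.shape[0]
    R = X - X @ W
    loss = 0.5 / n * (R ** 2).sum() + lambda1 * W.sum()
    G_loss = -1.0 / n * X.T @ R + lambda1
    return loss, G_loss


rng = np.random.default_rng(0)
X = rng.standard_normal((50, 4))
W = np.abs(rng.standard_normal((4, 4)))  # nonnegative, as the bounds enforce
lambda1, eps = 0.08, 1e-6

_, G = loss_and_grad(W, X, lambda1)
G_num = np.zeros_like(W)
for i in range(4):
    for j in range(4):
        E = np.zeros_like(W)
        E[i, j] = eps
        # Central finite difference of the loss along entry (i, j)
        G_num[i, j] = (loss_and_grad(W + E, X, lambda1)[0]
                       - loss_and_grad(W - E, X, lambda1)[0]) / (2 * eps)

print(np.max(np.abs(G - G_num)))  # tiny: only floating-point rounding remains
```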
d. Internal function _h
def _h(W):
    E = slin.expm(W * W)  # matrix exponential of the elementwise square W ∘ W
    h = np.trace(E) - d  # h(W) = tr(e^(W∘W)) - d
    G_h = E.T * W * 2  # gradient of h, Eq. 6
    return h, G_h
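- Computes the acyclicity constraint h(W) = tr(e^(W∘W)) - d and its gradient, where ∘ is the elementwise product.
- h(W) ≥ 0 for every W, and h(W) = 0 exactly when the graph encoded by W has no directed cycles; driving h to zero therefore makes the estimated graph acyclic.
- G_h = (e^(W∘W))ᵀ ∘ 2W is the gradient of h with respect to W.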
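A small numerical check of the constraint: h is (numerically) zero for a DAG and strictly positive once a directed cycle appears. The reason is that tr(e^A) = Σₖ tr(Aᵏ)/k!, and for a nonnegative A = W∘W the trace tr(Aᵏ) sums the weights of closed walks of length k, which are all zero when there is no cycle. h_and_grad is a standalone copy of _h written only for this illustration.

```python
import numpy as np
import scipy.linalg as slin


def h_and_grad(W):
    """Same formulas as _h, with d taken from W."""
    d = W.shape[0]
    E = slin.expm(W * W)
    h = np.trace(E) - d
    G_h = E.T * W * 2
    return h, G_h


W_dag = np.array([[0.0, 1.5, 0.0],
                  [0.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]])   # chain 0 -> 1 -> 2, no cycles
W_cycle = np.array([[0.0, 1.0],
                    [1.0, 0.0]])      # 0 -> 1 -> 0, a 2-cycle

print(h_and_grad(W_dag)[0])    # approximately 0
print(h_and_grad(W_cycle)[0])  # 2*cosh(1) - 2, about 1.086
```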
e. Internal function _func
def _func(w):
    W = _adj(w)
    loss, G_loss = _loss(W)
    h, G_h = _h(W)
    obj = loss + 0.5 * rho * h * h + alpha * h  # augmented Lagrangian, Eq. 11
    G_smooth = G_loss + (rho * h + alpha) * G_h  # gradient of Eq. 11
    g_obj = G_smooth.reshape(-1, )  # flatten to match the shape of w
    return obj, g_obj
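- Computes the complete objective (the augmented Lagrangian) and its gradient; this is the function passed to scipy.optimize.minimize.
- The objective combines the regularized loss from _loss with the quadratic penalty 0.5·rho·h² and the Lagrange term alpha·h for the acyclicity constraint. rho and alpha are read from the enclosing notears_linear, where the main loop updates them.
- g_obj is the gradient of the objective with respect to the flattened weight vector w.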
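Written out, the quantity that _func returns as obj, its gradient (G_smooth before flattening), and the multiplier update done later in the function body are as follows. Here h(W) = tr(e^{W∘W}) - d, the bounds keep W ≥ 0 with a zero diagonal (so the L1 norm reduces to the plain sum), and W* denotes the solution of the current subproblem:

```latex
\begin{aligned}
L^{\rho}(W, \alpha) &= \underbrace{\frac{1}{2n}\lVert X - XW \rVert_F^2 + \lambda_1 \sum_{i,j} W_{ij}}_{\text{loss}}
  + \frac{\rho}{2}\, h(W)^2 + \alpha\, h(W), \qquad W_{ij} \ge 0,\; W_{ii} = 0,\\
\nabla_W L^{\rho} &= \nabla_W \text{loss} + \bigl(\rho\, h(W) + \alpha\bigr)\, \nabla_W h(W),\\
\alpha &\leftarrow \alpha + \rho\, h(W^{\ast}) \quad \text{after each outer iteration.}
\end{aligned}
```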
f. Function body
n, d = X.shape
w_est, rho, alpha, h = np.zeros(d * d), 1.0, 0.0, np.inf
# Diagonal fixed at 0 (no self-loops); all other weights constrained to be >= 0
bnds = [(0, 0) if i == j else (0, None) for i in range(d) for j in range(d)]
X = X - np.mean(X, axis=0)  # center each variable
for _ in range(max_iter):
    w_new, h_new = None, None
    while rho < rho_max:
        # With bounds and no method given, SciPy uses L-BFGS-B
        sol = sopt.minimize(_func, w_est, jac=True, bounds=bnds)
        w_new = sol.x
        h_new, _ = _h(_adj(w_new))
        if h_new > 0.25 * h:  # h is not decreasing fast enough: increase its weight rho
            rho *= 10
        else:
            break
    w_est, h = w_new, h_new
    alpha += rho * h  # dual ascent step on the Lagrange multiplier
    if h <= h_tol or rho >= rho_max:
        break
W_est = _adj(w_est)
W_est[np.abs(W_est) < w_threshold] = 0
return W_est
- Initialize variables
  - Get the dimensions n and d of the input data matrix and initialize the weight vector w_est (all zeros), the penalty parameter rho = 1, the Lagrange multiplier alpha = 0 and the constraint value h = np.inf.
  - Build the bounds: diagonal entries are fixed at 0 (no self-loops) and all other weights must be nonnegative.
  - Center each column of the input data matrix.
- Loop iteration
  - In each outer iteration, scipy.optimize.minimize minimizes the augmented Lagrangian _func over w, starting from the current estimate.
  - If the new constraint value h_new has not dropped below a quarter of the previous h, rho is multiplied by 10 and the subproblem is solved again. A larger rho penalizes violations of acyclicity more strongly; sparsity is controlled separately by lambda1 and w_threshold.
  - After each outer iteration, the Lagrange multiplier is updated with alpha += rho * h.
  - The loop stops when h ≤ h_tol (the graph is acyclic up to the tolerance), when rho reaches rho_max, or after max_iter iterations.
- Thresholding: set every entry of the weight matrix whose absolute value is below w_threshold to zero.
- Return the estimated weight matrix W_est.
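The initialization of the bounds and the final thresholding step can be seen on a small case. The sketch below reuses the bounds comprehension from notears_linear for d = 3 and applies the threshold to a hand-written matrix (illustrative numbers, not output of the algorithm).

```python
import numpy as np

d = 3
# Same comprehension as in notears_linear: i is the row (source), j the column (target)
bnds = [(0, 0) if i == j else (0, None) for i in range(d) for j in range(d)]
print(bnds[:4])  # [(0, 0), (0, None), (0, None), (0, None)]
# Parameters i * d + i (the diagonal entries W[i, i]) are pinned to 0: no self-loops
print([k for k, b in enumerate(bnds) if b == (0, 0)])  # [0, 4, 8]

# Thresholding a made-up estimate with w_threshold = 0.3
W_est = np.array([[0.0, 1.42, 0.05],
                  [0.12, 0.0, 1.97],
                  [0.0, 0.28, 0.0]])
W_est[np.abs(W_est) < 0.3] = 0
print(W_est)  # only the 1.42 and 1.97 entries survive
```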
3. Main program
if __name__ == '__main__':
    set_random_seed()
    X = np.loadtxt('Notears_X.csv', delimiter=',')
    W_est = notears_linear(X)
    print("W_est")
    print(W_est)
    G_nx = nx.DiGraph(W_est)
    print(nx.is_directed_acyclic_graph(G_nx))
    nx.draw_planar(G_nx, with_labels=True)
    plt.show()
- Set the random seed.
- Load the data matrix X from the file "Notears_X.csv".
- Call notears_linear to estimate the weight matrix W_est of the linear structural equation model.
- Print the estimated weight matrix W_est.
- Create a directed graph G_nx from W_est (an edge i → j for every nonzero W_est[i, j]).
- Check whether G_nx is a directed acyclic graph (DAG) and print the result.
- Draw G_nx with a planar layout and display the figure (nx.draw_planar raises an exception if the graph is not planar).
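Two details of the main program can be checked on a small example: nx.DiGraph turns every nonzero entry W_est[i, j] into an edge i → j and stores the value as the edge attribute "weight", and nx.draw_planar only works for planar graphs. The helper safe_layout is hypothetical (not part of the original code); it falls back to a circular layout when nx.check_planarity reports that the graph is not planar.

```python
import numpy as np
import networkx as nx

W_est = np.array([[0.0, 1.5, 0.0],
                  [0.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]])  # illustrative estimate: chain 0 -> 1 -> 2

G_nx = nx.DiGraph(W_est)  # nonzero W_est[i, j] becomes edge i -> j with attribute 'weight'
print(sorted(G_nx.edges()))                # [(0, 1), (1, 2)]
print(nx.is_directed_acyclic_graph(G_nx))  # True


def safe_layout(G):
    """Hypothetical helper: planar layout when possible, otherwise a circular layout."""
    is_planar, _ = nx.check_planarity(G)
    return nx.planar_layout(G) if is_planar else nx.circular_layout(G)


pos = safe_layout(G_nx)
# To draw: nx.draw(G_nx, pos, with_labels=True); plt.show()  (needs matplotlib.pyplot as plt)
```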
Partial view of the data (Notears_X.csv):
6.24E-01 | 9.07E-01 | 7.77E-01 | 1.58E+00 | ####### | ####### | 5.62E+00 | ####### | ####### | 7.16E+00 |
7.50E-01 | 7.33E-01 | ####### | 7.01E-03 | ####### | ####### | 3.93E-01 | ####### | 2.40E+00 | ####### |
3.77E-01 | 7.12E-01 | 1.71E-01 | 1.58E-01 | 1.08E+00 | 1.73E+00 | ####### | 3.05E+00 | 4.09E+00 | ####### |
1.39E-01 | 1.10E+00 | 7.96E-01 | 1.67E+00 | 2.94E-01 | ####### | 4.86E+00 | ####### | ####### | 7.24E+00 |
####### | ####### | ####### | ####### | ####### | 9.83E-01 | ####### | 3.42E+00 | 4.28E+00 | ####### |
####### | ####### | 8.44E-01 | 5.92E-01 | 9.75E-02 | ####### | ####### | 8.99E-01 | ####### | 1.18E+00 |
####### | 1.68E-01 | ####### | ####### | 1.50E+00 | 3.22E+00 | ####### | 3.14E+00 | 4.26E+00 | ####### |
####### | ####### | 2.18E-01 | ####### | 1.18E+00 | 2.19E+00 | ####### | 1.41E+00 | 9.86E-01 | ####### |
1.85E-01 | 3.48E-02 | 3.65E-01 | ####### | 3.91E-01 | 1.97E+00 | ####### | 4.16E+00 | 4.85E+00 | ####### |
####### | ####### | ####### | ####### | 3.01E-01 | 7.11E-01 | 2.77E-02 | ####### | ####### | ####### |
Figure: drawing of the estimated graph
n. Complete code
import numpy as np
import scipy.linalg as slin
import scipy.optimize as sopt
import random
import networkx as nx
import matplotlib.pyplot as plt


def set_random_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)


def notears_linear(X, lambda1=0.08, max_iter=100, h_tol=1e-8, rho_max=1e+16, w_threshold=0.3):
    def _adj(w):
        return w.reshape([d, d])

    def _loss(W):
        X_ = X @ W
        R = X - X_
        loss = 0.5 / X.shape[0] * (R ** 2).sum() + lambda1 * W.sum()  # Eq. 2
        G_loss = - 1.0 / X.shape[0] * X.T @ R + lambda1
        return loss, G_loss

    def _h(W):
        E = slin.expm(W * W)
        h = np.trace(E) - d
        G_h = E.T * W * 2  # Eq. 6
        return h, G_h

    def _func(w):
        W = _adj(w)
        loss, G_loss = _loss(W)
        h, G_h = _h(W)
        obj = loss + 0.5 * rho * h * h + alpha * h  # Eq. 11
        G_smooth = G_loss + (rho * h + alpha) * G_h  # gradient of Eq. 11
        g_obj = G_smooth.reshape(-1, )
        return obj, g_obj

    n, d = X.shape
    w_est, rho, alpha, h = np.zeros(d * d), 1.0, 0.0, np.inf
    bnds = [(0, 0) if i == j else (0, None) for i in range(d) for j in range(d)]
    X = X - np.mean(X, axis=0)
    for _ in range(max_iter):
        w_new, h_new = None, None
        while rho < rho_max:
            sol = sopt.minimize(_func, w_est, jac=True, bounds=bnds)
            w_new = sol.x
            h_new, _ = _h(_adj(w_new))
            if h_new > 0.25 * h:  # h is not decreasing fast enough: increase its weight rho
                rho *= 10
            else:
                break
        w_est, h = w_new, h_new
        alpha += rho * h
        if h <= h_tol or rho >= rho_max:
            break
    W_est = _adj(w_est)
    W_est[np.abs(W_est) < w_threshold] = 0
    return W_est


if __name__ == '__main__':
    set_random_seed()
    X = np.loadtxt('Notears_X.csv', delimiter=',')
    W_est = notears_linear(X)
    print("W_est")
    print(W_est)
    G_nx = nx.DiGraph(W_est)
    print(nx.is_directed_acyclic_graph(G_nx))
    nx.draw_planar(G_nx, with_labels=True)
    plt.show()
    # edges, weights = zip(*nx.get_edge_attributes(G_nx, 'weight').items())
    # pos = nx.spring_layout(G_nx)
    # nx.draw(G_nx, pos, node_color='b', edgelist=edges, edge_color=weights, width=5, with_labels=True, edge_cmap=plt.cm.Blues)
    # plt.show()