Matrix factorization with TensorFlow

Solving least squares by explicit matrix inversion is inefficient in most cases, and it becomes even less efficient when the matrices are very large. An alternative is matrix decomposition. Here we use TensorFlow's built-in Cholesky decomposition. The Cholesky decomposition factors a symmetric positive-definite matrix into the product of a lower triangular matrix L and its transpose L', which is upper triangular. A is not square, so the factorization is applied to the normal equations A'Ax = A'b. Writing A'A = LL' turns the system into LL'x = A'b. First solve Ly = A'b for y, then solve L'x = y to get the coefficient vector x.
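The same steps in matrix notation, using ' for the transpose as in the text. A'A is positive definite, and therefore has a Cholesky factor, only when A has full column rank. For the line fit this means the x values must not all be equal.

```latex
\begin{aligned}
A\,x &= b && \text{overdetermined: } A \text{ has more rows than columns}\\
A^{\top}A\,x &= A^{\top}b && \text{normal equations; } A^{\top}A \text{ is square and SPD}\\
A^{\top}A &= L\,L^{\top} && L \text{ lower triangular (Cholesky factor)}\\
L\,y &= A^{\top}b && \text{forward substitution for } y\\
L^{\top}x &= y && \text{back substitution for } x
\end{aligned}
```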

1. Import the libraries, initialize the computational graph, and generate the dataset. Then build the matrices A and b.

>>> import matplotlib.pyplot as plt
>>> import numpy as np

>>> import tensorflow as tf

>>> from tensorflow.python.framework import ops
>>> ops.reset_default_graph()

>>> sess=tf.Session()

>>> # 100 evenly spaced x values; y = x plus Gaussian noise
>>> x_vals=np.linspace(0,10,100)
>>> y_vals=x_vals+np.random.normal(0,1,100)
>>> # Design matrix A = [x, 1] (100x2) and target column b (100x1)
>>> x_vals_column=np.transpose(np.matrix(x_vals))
>>> ones_column=np.transpose(np.matrix(np.repeat(1,100)))
>>> A=np.column_stack((x_vals_column,ones_column))
>>> b=np.transpose(np.matrix(y_vals))
>>> A_tensor=tf.constant(A)
>>> b_tensor=tf.constant(b)
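NumPy still supports `np.matrix`, but it discourages its use in new code. The following sketch builds the same A and b from plain ndarrays. The seeded generator `default_rng(0)` is an addition for reproducibility and is not part of the original.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumption: seeded generator, for reproducibility
x_vals = np.linspace(0, 10, 100)
y_vals = x_vals + rng.normal(0, 1, 100)

# Design matrix: first column x, second column all ones (the intercept term)
A = np.column_stack((x_vals, np.ones_like(x_vals)))  # shape (100, 2)
b = y_vals.reshape(-1, 1)                              # shape (100, 1)

print(A.shape, b.shape)  # (100, 2) (100, 1)
```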

2. Compute the Cholesky factorization of the square matrix A'A.

Note: TensorFlow's cholesky() function returns only the lower triangular factor, because the upper triangular factor is simply its transpose.
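The note above can be checked numerically. The sketch below uses NumPy's `np.linalg.cholesky` in place of TensorFlow's `cholesky()` and a small hypothetical 4-point design matrix. NumPy likewise returns only the lower factor, and multiplying it by its transpose recovers A'A.

```python
import numpy as np

# A small design matrix [x, 1]; A'A is symmetric positive definite
x = np.array([0.0, 1.0, 2.0, 3.0])
A = np.column_stack((x, np.ones_like(x)))
AtA = A.T @ A

L = np.linalg.cholesky(AtA)        # only the lower triangular factor is returned
assert np.allclose(L, np.tril(L))  # entries above the diagonal are zero
assert np.allclose(L @ L.T, AtA)   # the upper factor is just L transposed
print(L)
```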

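>>> # Normal equations: (A'A) x = A'b, where A'A is 2x2 and positive definite
>>> tA_A=tf.matmul(tf.transpose(A_tensor),A_tensor)
>>> L=tf.cholesky(tA_A)  # lower triangular factor, A'A = LL'
>>> tA_b=tf.matmul(tf.transpose(A_tensor),b_tensor)
>>> sol1=tf.matrix_solve(L,tA_b)  # solve Ly = A'b for y
>>> sol2=tf.matrix_solve(tf.transpose(L),sol1)  # solve L'x = y for x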
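The two `matrix_solve` calls perform a forward substitution and then a back substitution. The sketch below writes those steps out explicitly. NumPy stands in for TensorFlow, and the names `forward_substitution`, `back_substitution` and `cholesky_least_squares` are hypothetical helpers. The result is compared against `np.linalg.lstsq`.

```python
import numpy as np

def forward_substitution(L, b):
    """Solve L y = b for lower-triangular L, working from the top row down."""
    n = L.shape[0]
    y = np.zeros(n)
    for i in range(n):
        y[i] = (b[i] - L[i, :i] @ y[:i]) / L[i, i]
    return y

def back_substitution(U, y):
    """Solve U x = y for upper-triangular U, working from the bottom row up."""
    n = U.shape[0]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

def cholesky_least_squares(A, b):
    """Least-squares fit via the normal equations A'A x = A'b and Cholesky."""
    L = np.linalg.cholesky(A.T @ A)       # A'A = L L'
    y = forward_substitution(L, A.T @ b)  # first solve L y = A'b
    return back_substitution(L.T, y)      # then solve L' x = y

rng = np.random.default_rng(0)  # assumption: seeded data for reproducibility
x_vals = np.linspace(0, 10, 100)
y_vals = x_vals + rng.normal(0, 1, 100)
A = np.column_stack((x_vals, np.ones_like(x_vals)))

slope, intercept = cholesky_least_squares(A, y_vals)
ref, *_ = np.linalg.lstsq(A, y_vals, rcond=None)
print(slope, intercept, np.allclose([slope, intercept], ref))
```

`tf.matrix_solve` is a general solver and does not exploit the fact that L is triangular. In TensorFlow 1.x, `tf.matrix_triangular_solve(L, rhs, lower=True)` is the more direct choice for these two steps.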

3. Extract the coefficients

>>> solution_eval=sess.run(sol2)
>>> solution_eval
array([[1.01379067],
       [0.02290901]])
>>> slope=solution_eval[0][0]
>>> y_intercept=solution_eval[1][0]
>>> print('slope:'+str(slope))
slope:1.0137906744047482
>>> print('y_intercept:'+str(y_intercept))
y_intercept:0.022909011828880693
>>> best_fit=[]
>>> for i in x_vals:
...   best_fit.append(slope*i+y_intercept)
...
>>> plt.plot(x_vals,y_vals,'o',label='Data')
[<matplotlib.lines.Line2D object at 0x000001E0A58DD9B0>]
>>> plt.plot(x_vals,best_fit,'r-',label='Best fit line',linewidth=3)
[<matplotlib.lines.Line2D object at 0x000001E0A2DFAF98>]
>>> plt.legend(loc='upper left')
<matplotlib.legend.Legend object at 0x000001E0A58F03C8>

>>> plt.show()
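To sanity-check the extracted coefficients, you can refit with `np.polyfit`. With `deg=1`, it returns `[slope, intercept]`, highest power first. The `best_fit` loop can also be replaced by a single vectorized expression. The sketch below regenerates similar data with a seeded generator, which is an assumption, so the numbers will not match the session above exactly.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumption: seeded data instead of the unseeded draw
x_vals = np.linspace(0, 10, 100)
y_vals = x_vals + rng.normal(0, 1, 100)

# Degree-1 polynomial fit: coefficients come back highest power first
slope, y_intercept = np.polyfit(x_vals, y_vals, 1)

# Vectorized replacement for the best_fit loop
best_fit = slope * x_vals + y_intercept
print('slope:', slope, 'y_intercept:', y_intercept)
```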

