Gradient descent algorithm
First, the principle
Gradient descent algorithm is used to solve a minimum function algorithm, detailed description of the algorithm will not elaborate formulas directly
The entire cost function Andrew Ng machine learning to illustrate
-
Cost function
\[J(\theta_0,\theta_1) \] -
Gradient descent algorithm
\[\begin{align} \theta_j &= \theta_j - \alpha\frac{\partial}{\partial\theta_j}J(\theta_0, \theta_1) \\ j &= 0,1 \end{align} \]
Second, the practice (python)
The objective function
\ [F (x, y) = -e ^ {- (x ^ 2 + y ^ 2)} \]
import numpy as np
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
X,Y = np.meshgrid(x,y)
Z = -np.exp(-(X**2 + Y**2))
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(X, Y, Z, rstride = 1, cstride = 1, cmap = 'rainbow')
plt.show()
-
Function is relatively simple example here, the minimum value acquired at (0, 0)
-
code show as below
import math import numpy as np # x type: list, x[0], x[1] def grad_2d(x): temp0 = 2 * x[0] * math.exp(-(x[0]**2 + x[1]**2)) temp1 = 2 * x[1] * math.exp(-(x[0]**2 + x[1]**2)) return np.array([temp0, temp1]) def gradient(grad, init_val = np.array([0, 0]), learning_rate = 0.01, precision = 0.0001, max_iters = 10000): print('init val:', init_val) cur_val = init_val for i in range(max_iters): grad_cur = grad(cur_val) if np.linalg.norm(grad_cur, ord=2) < precision: break cur_val = cur_val - grad_cur * learning_rate print('第', i, '次迭代, 当前 x 为:', cur_val) print('min x = ', cur_val) return cur_val if __name__ == '__main__': gradient(grad_2d, init_val=np.array([1, -1]), learning_rate=0.2, precision=0.000001, max_iters=10000)
-
operation result
init val: [ 1 -1] 第 0 次迭代, 当前 x 为: [ 0.94586589 -0.94586589] 第 1 次迭代, 当前 x 为: [ 0.88265443 -0.88265443] 第 2 次迭代, 当前 x 为: [ 0.80832661 -0.80832661] 第 3 次迭代, 当前 x 为: [ 0.72080448 -0.72080448] 第 4 次迭代, 当前 x 为: [ 0.61880589 -0.61880589] 第 5 次迭代, 当前 x 为: [ 0.50372222 -0.50372222] 第 6 次迭代, 当前 x 为: [ 0.3824228 -0.3824228] 第 7 次迭代, 当前 x 为: [ 0.26824673 -0.26824673] 第 8 次迭代, 当前 x 为: [ 0.17532999 -0.17532999] 第 9 次迭代, 当前 x 为: [ 0.10937992 -0.10937992] 第 10 次迭代, 当前 x 为: [ 0.06666242 -0.06666242] 第 11 次迭代, 当前 x 为: [ 0.04023339 -0.04023339] 第 12 次迭代, 当前 x 为: [ 0.02419205 -0.02419205] 第 13 次迭代, 当前 x 为: [ 0.01452655 -0.01452655] 第 14 次迭代, 当前 x 为: [ 0.00871838 -0.00871838] 第 15 次迭代, 当前 x 为: [ 0.00523156 -0.00523156] 第 16 次迭代, 当前 x 为: [ 0.00313905 -0.00313905] 第 17 次迭代, 当前 x 为: [ 0.00188346 -0.00188346] 第 18 次迭代, 当前 x 为: [ 0.00113008 -0.00113008] 第 19 次迭代, 当前 x 为: [ 0.00067805 -0.00067805] 第 20 次迭代, 当前 x 为: [ 0.00040683 -0.00040683] 第 21 次迭代, 当前 x 为: [ 0.0002441 -0.0002441] 第 22 次迭代, 当前 x 为: [ 0.00014646 -0.00014646] 第 23 次迭代, 当前 x 为: [ 8.78751305e-05 -8.78751305e-05] 第 24 次迭代, 当前 x 为: [ 5.27250788e-05 -5.27250788e-05] 第 25 次迭代, 当前 x 为: [ 3.16350474e-05 -3.16350474e-05] 第 26 次迭代, 当前 x 为: [ 1.89810285e-05 -1.89810285e-05] 第 27 次迭代, 当前 x 为: [ 1.13886171e-05 -1.13886171e-05] 第 28 次迭代, 当前 x 为: [ 6.83317026e-06 -6.83317026e-06] 第 29 次迭代, 当前 x 为: [ 4.09990215e-06 -4.09990215e-06] 第 30 次迭代, 当前 x 为: [ 2.45994129e-06 -2.45994129e-06] 第 31 次迭代, 当前 x 为: [ 1.47596478e-06 -1.47596478e-06] 第 32 次迭代, 当前 x 为: [ 8.85578865e-07 -8.85578865e-07] 第 33 次迭代, 当前 x 为: [ 5.31347319e-07 -5.31347319e-07] 第 34 次迭代, 当前 x 为: [ 3.18808392e-07 -3.18808392e-07] min x = [ 3.18808392e-07 -3.18808392e-07]