Gradient descent principle and Python implementation

Gradient descent is a fundamental algorithm that plays a central role in machine learning and optimization. This article first introduces the basic concepts of gradient descent and then implements a simple version in Python. Gradient descent has many variants; this article covers only the most basic one, batch gradient descent.

Practical applications will not be described in detail here; many examples of gradient descent can be found online. The most common is the house-price prediction example from Andrew Ng's machine learning course:
Suppose we have house sales data as follows:

Area (m^2)    Sales price (10,000 yuan)
123           250
150           320
87            180

Plotting price against area gives a scatter plot:

[Figure: scatter plot of area vs. sales price]

Our goal is to fit this data, so that we can easily make a prediction when a new sample arrives:

[Figure: fitted regression line through the data points]

For the most basic linear regression problem, the hypothesis is:

h_θ(x) = θ₀ + θ₁·x

Here x is the independent variable, say the size of the house, and θ denotes the weight parameters, which are the values we solve for with gradient descent.
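As a concrete sketch, the hypothesis for a single feature can be written as a plain Python function. The parameter values below are made up purely for illustration, not fitted to the table above:

```python
# Linear hypothesis h(x) = theta0 + theta1 * x for one feature.
def h(x, theta0, theta1):
    return theta0 + theta1 * x

# Illustrative (made-up) parameters: base price 50, plus 1.6 per m^2.
print(h(123, 50.0, 1.6))  # predicted price for a 123 m^2 house
```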

Next we introduce a loss function (also called a cost function), which measures whether the parameters are moving in the right direction during gradient descent. With m denoting the number of training samples, the loss function is:

J(θ) = (1/(2m)) · Σ_{i=1..m} (h_θ(x^(i)) − y^(i))²
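This loss function can be sketched in a few lines of plain Python (a minimal version; the function name is my own):

```python
# Mean-squared-error loss: J = 1/(2m) * sum over i of (h(x_i) - y_i)^2
def loss(predictions, targets):
    m = len(targets)
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / (2 * m)

# A perfect fit gives zero loss; errors of -10 and +10 give (100+100)/(2*3).
print(loss([250.0, 320.0, 180.0], [250.0, 320.0, 180.0]))  # 0.0
print(loss([240.0, 330.0, 180.0], [250.0, 320.0, 180.0]))
```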
The following figure illustrates the direction of gradient descent: starting from a high point on the loss surface and descending toward the lowest point:

[Figure: loss surface with a descent path from a high point to the minimum]

To update the weight parameters by gradient descent, we need the partial derivative of the loss function:

∂J(θ)/∂θ_j = (1/m) · Σ_{i=1..m} (h_θ(x^(i)) − y^(i)) · x_j^(i)

Once the partial derivative is obtained, the parameters can be updated (α is the learning rate):

θ_j := θ_j − α · (1/m) · Σ_{i=1..m} (h_θ(x^(i)) − y^(i)) · x_j^(i)

In pseudocode:

repeat until convergence {
    compute the gradient over the whole training set
    update θ_j simultaneously for every j using the rule above
}

Now for the implementation. We use Python to implement gradient descent to fit

y = 2·x1 + x2 + 3

that is, to solve for the three parameters a, b, and c in

y = a·x1 + b·x2 + c

Here is the code:

import numpy as np
import matplotlib.pyplot as plt

# Fit y = 2*x1 + x2 + 3
rate = 0.001

x_train = np.array([[1, 2], [2, 1], [2, 3], [3, 5], [1, 3],
                    [4, 2], [7, 3], [4, 5], [11, 3], [8, 7]])
y_train = np.array([7, 8, 10, 14, 8, 13, 20, 16, 28, 26])
x_test = np.array([[1, 4], [2, 2], [2, 5], [5, 3], [1, 5], [4, 1]])

# Random initialization of the three parameters.
a = np.random.normal()
b = np.random.normal()
c = np.random.normal()

def h(x):
    return a * x[0] + b * x[1] + c

for i in range(10000):
    sum_a = 0
    sum_b = 0
    sum_c = 0
    # Accumulate the (scaled) gradient over the whole training batch.
    for x, y in zip(x_train, y_train):
        sum_a = sum_a + rate * (y - h(x)) * x[0]
        sum_b = sum_b + rate * (y - h(x)) * x[1]
        sum_c = sum_c + rate * (y - h(x))
    a = a + sum_a
    b = b + sum_b
    c = c + sum_c
    # Plot the current predictions on the test set at every iteration.
    plt.plot([h(xi) for xi in x_test])

print(a)
print(b)
print(c)
result = [h(xi) for xi in x_train]
print(result)
result = [h(xi) for xi in x_test]
print(result)
plt.show()

x_train is the training set inputs, y_train the training set targets, and x_test the test set inputs. Running the script produces the following picture, which shows how the algorithm's prediction for the test set changes over the iterations:

[Figure: predicted test-set values converging over the iterations]

We can see the line segments gradually converging; with more training data and more iterations, the predictions get closer to the true values.
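For comparison, the same batch gradient descent can be written in a vectorized NumPy form. This is a sketch, not the article's original code: the intercept c is folded in as a third weight by appending a column of ones, and the gradient is averaged over the batch. On the training data above it recovers values close to a = 2, b = 1, c = 3:

```python
import numpy as np

x_train = np.array([[1, 2], [2, 1], [2, 3], [3, 5], [1, 3],
                    [4, 2], [7, 3], [4, 5], [11, 3], [8, 7]], dtype=float)
y_train = np.array([7, 8, 10, 14, 8, 13, 20, 16, 28, 26], dtype=float)

# Append a column of ones so the intercept c becomes just another weight.
X = np.hstack([x_train, np.ones((len(x_train), 1))])
theta = np.zeros(3)  # [a, b, c]
rate = 0.01

for _ in range(20000):
    # Batch gradient of the mean-squared-error loss: X^T (X theta - y) / m
    grad = X.T @ (X @ theta - y_train) / len(y_train)
    theta -= rate * grad

print(theta)  # approximately [2, 1, 3]
```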

Reference article: 
http://www.cnblogs.com/LeftNotEasy/archive/2010/12/05/mathmatic_in_machine_learning_1_regression_and_gradient_descent.html

http://www.cnblogs.com/eczhou/p/3951861.html
