Gradient Descent for Neural Networks: Python Implementation 01

0. Preface

The previous post explained the principle behind gradient descent in neural networks. Here I want to implement the example from that post, namely using gradient descent to find the minimum of a function.

1. Finding the Minimum of a Quadratic Function, with Visualization

Find the minimum of a simple quadratic function.
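
The update rule implemented below is the one described in the previous post: starting from an initial value of $x$, repeatedly step against the gradient,

$x \leftarrow x - \alpha \cdot f'(x)$

where $\alpha$ is the step size. For $f(x) = 0.5(x - 0.3)^2$ the derivative is $f'(x) = x - 0.3$, so every step pulls $x$ toward the minimum at $x = 0.3$.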

import numpy as np


def function(x):
    """原函数"""
    return 0.5 * (x - 0.3) ** 2


def function_derivative(x):
    """求导后的函数"""
    return 0.5 * 2 * (x - 0.3)


GD_x = []  # stores the value of x at each iteration
GD_y = []  # stores the corresponding function value
x = 4  # initial value of x
alpha = 0.5  # step size (learning rate)
f_x = function(x)
f_now = f_x
GD_x.append(x)
GD_y.append(f_now)
iter_num = 0  # iteration counter

while f_x > 1e-10 and iter_num < 100:  # stop when the change in y becomes tiny, and never iterate forever
    iter_num += 1
    x = x - alpha * function_derivative(x)  # the step size and the sign of the derivative decide how x moves next
    temp = function(x)  # new function value
    f_x = np.abs(f_now - temp)  # change in y, used as the stopping criterion
    f_now = temp
    GD_x.append(x)
    GD_y.append(f_now)

print("最终结果:x = {:.2f}, y = {:.2f}".format(x, f_now))
print("迭代次数:", iter_num)
print("GD_x:", GD_x)
print("GD_y:", GD_y)

Output:

Final result: x = 0.30, y = 0.00
Number of iterations: 19
GD_x: [4, 2.15, 1.225, 0.7625000000000001, 0.53125, 0.415625, 0.3578125, 0.32890624999999996, 0.314453125, 0.3072265625, 0.30361328125, 0.30180664062499996, 0.3009033203125, 0.30045166015625, 0.300225830078125, 0.30011291503906246, 0.30005645751953125, 0.30002822875976565, 0.3000141143798828, 0.30000705718994136]
GD_y: [6.845000000000001, 1.7112499999999997, 0.42781250000000004, 0.10695312500000004, 0.026738281250000002, 0.006684570312500004, 0.0016711425781249993, 0.00041778564453124906, 0.00010444641113281266, 2.6111602783203365e-05, 6.527900695800741e-06, 1.6319751739501351e-06, 4.0799379348755887e-07, 1.0199844837190225e-07, 2.5499612092969296e-08, 6.37490302323919e-09, 1.5937257558113643e-09, 3.9843143895362463e-10, 9.96078597380144e-11, 2.4901964934307722e-11]

As the results show, the minimum is reached after 19 iterations; GD_x and GD_y record the value of x and the corresponding y at every step. With alpha = 0.5 the update is x = x - 0.5 * (x - 0.3), so the distance to the minimum at x = 0.3 is halved at each iteration, which is exactly the pattern visible in GD_x.
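
How quickly this converges depends mainly on the step size. As a minimal sketch (the gradient_descent helper and the alpha values below are my own illustration, not part of the original example), we can rerun the same loop with several step sizes and compare iteration counts; it reuses function and function_derivative defined above:

def gradient_descent(x0, alpha, tol=1e-10, max_iter=100):
    """Run the same update loop as above; return the final x and the iteration count."""
    x = x0
    f_prev = function(x)
    for i in range(1, max_iter + 1):
        x = x - alpha * function_derivative(x)
        f_now = function(x)
        if abs(f_prev - f_now) <= tol:  # same stopping rule: the change in y is tiny
            return x, i
        f_prev = f_now
    return x, max_iter

# Illustrative step sizes: too small converges slowly, too large overshoots
for a in (0.1, 0.5, 1.0, 1.5):
    x_final, n = gradient_descent(x0=4, alpha=a)
    print("alpha = {:.1f}: x = {:.4f}, iterations = {}".format(a, x_final, n))

For this particular function the update simplifies to x = 0.3 + (1 - alpha) * (x - 0.3), so the iteration converges only for 0 < alpha < 2 and is fastest near alpha = 1.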

Visualization
To make this more intuitive, let's plot the process above.

from matplotlib import pyplot as plt
# Visualization: plot the curve and the path taken by gradient descent
X = np.arange(-4, 4.5, 0.06)
Y = np.array(list(map(lambda t: function(t), X)))  # equivalent to function(X), since the function is vectorized

plt.figure()
plt.title('$y = 0.5*(x - 0.3)^2$\nLearning rate: {:.2f}; Result: ({:.2f},{:.2f}); Iteration num: {}'.format(alpha, x, f_now, iter_num))  # f_now holds the final function value (f_x only holds the last change in y)
plt.plot(X, Y, 'g-')  # the quadratic curve
plt.plot(GD_x, GD_y, 'ro--')  # the points visited by gradient descent
plt.show()

Visualization result:
(Figure: the green quadratic curve with the gradient descent path shown as red dots converging to x = 0.3.)

Reposted from blog.csdn.net/weixin_41857483/article/details/109731500