import numpy as np
import matplotlib.pyplot as plt
# Global matplotlib defaults applied to every figure drawn by this script:
# 16pt text, the STKaiti typeface (presumably chosen so the Chinese legend
# label renders; the font must be installed), and an ASCII hyphen-minus
# for negative tick labels instead of the Unicode minus sign.
plt.rcParams.update({
    'font.size': 16,
    'font.family': ['STKaiti'],
    'axes.unicode_minus': False,
})
def compute_error(b, w, data):
    """Return the mean squared error of the line ``y = w * x + b`` over *data*.

    Parameters
    ----------
    b : float
        Intercept of the candidate line.
    w : float
        Slope of the candidate line.
    data : array-like
        Rows of ``(x, y)`` points: column 0 is x, column 1 is y.

    Returns
    -------
    float
        ``mean((w * x + b - y) ** 2)`` over all rows.

    Raises
    ------
    ValueError
        If *data* contains no rows (the mean would be undefined).
    """
    points = np.asarray(data, dtype=float)
    if len(points) == 0:
        raise ValueError("compute_error() needs at least one data point")
    # Vectorized residuals replace the per-row Python loop.
    residuals = w * points[:, 0] + b - points[:, 1]
    return float(np.mean(residuals ** 2))
def compute_gradient(b, w, data, learning_rate):
    """Take one batch gradient-descent step on the mean squared error.

    The loss is ``L(b, w) = mean((w * x + b - y) ** 2)``. Its partial
    derivatives are ``dL/db = 2 * mean(r)`` and ``dL/dw = 2 * mean(x * r)``,
    where ``r = w * x + b - y``.

    Parameters
    ----------
    b, w : float
        Current intercept and slope.
    data : array-like
        Rows of ``(x, y)`` points: column 0 is x, column 1 is y.
    learning_rate : float
        Step size applied to both gradients.

    Returns
    -------
    tuple of float
        ``(new_b, new_w)`` after one update.

    Raises
    ------
    ValueError
        If *data* contains no rows.
    """
    points = np.asarray(data, dtype=float)
    if len(points) == 0:
        raise ValueError("compute_gradient() needs at least one data point")
    x, y = points[:, 0], points[:, 1]
    residuals = w * x + b - y
    grad_b = 2.0 * np.mean(residuals)
    grad_w = 2.0 * np.mean(x * residuals)
    return b - learning_rate * grad_b, w - learning_rate * grad_w


def gradient_descend(b, w, learning_rate, iterations, data):
    """Fit ``y = w * x + b`` to *data* with batch gradient descent.

    Parameters
    ----------
    b, w : float
        Initial intercept and slope.
    learning_rate : float
        Step size passed to every gradient update.
    iterations : int
        Number of update steps to perform.
    data : array-like
        Rows of ``(x, y)`` points: column 0 is x, column 1 is y.

    Returns
    -------
    tuple
        ``([b, w], losses)``, where ``[b, w]`` are the parameters after the
        final update and ``losses[i]`` is the MSE measured *before* update
        ``i`` (so ``len(losses) == iterations``).

    Every 50 iterations a progress line is printed. Its loss is measured
    *after* that iteration's update.
    """
    # Convert once up front instead of re-copying the array every iteration.
    points = np.asarray(data)
    new_b, new_w = b, w
    losses = []
    for i in range(iterations):
        losses.append(compute_error(new_b, new_w, points))
        new_b, new_w = compute_gradient(new_b, new_w, points, learning_rate)
        if i % 50 == 0:  # print the loss and the current w, b values
            print(f"iteration:{i}, loss:{compute_error(new_b, new_w, points)}, w:{new_w}, b:{new_b}")
    return [new_b, new_w], losses
def run():
    """Fit a line to ``data.csv`` by gradient descent and plot the loss curve.

    Reads comma-separated ``x,y`` rows from ``data.csv`` in the current
    working directory. It trains from ``b = w = 0`` and prints the initial
    and final parameters and error. It then shows a plot of the
    per-iteration mean squared error.
    """
    data = np.genfromtxt("data.csv", delimiter=',')
    initial_b = 0
    initial_w = 0
    learning_rate = 0.0001
    iterations = 1000
    print('Starting gradient descent at b = {0}, w = {1}, error = {2}'.format(
        initial_b, initial_w, compute_error(initial_b, initial_w, data)))
    print('running...')
    # Train once for `iterations` steps; returns the optimized w*, b* and the
    # history of the training loss.
    [b, w], losses = gradient_descend(initial_b, initial_w, learning_rate, iterations, data)
    print('After {0} iterations, b = {1}, w = {2}, error = {3}'.format(
        iterations, b, w, compute_error(b, w, data)))
    # One loss value per iteration, so derive the x-axis from the history
    # rather than hard-coding its length.
    epochs = range(len(losses))
    # Legend label is Chinese for "mean squared error".
    plt.plot(epochs, losses, color='C1', label='均方差')
    plt.xlabel('Epoch')
    plt.ylabel('ERROR')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    run()
输出:
Starting gradient descent at b = 0, w = 0, error = 5565.107834490552 iteration:0, loss:1484.5865573886724, w:0.7370702973620529, b:0.014547010110780006 iteration:50, loss:112.64882487923404, w:1.478860560860723, b:0.03213199290371721 iteration:100, loss:112.64702055487857, w:1.4788015372775043, b:0.0351350200320317 iteration:150, loss:112.64521759187215, w:1.4787425359649087, b:0.03813691406602341 iteration:200, loss:112.64341598918755, w:1.4786835569145336, b:0.04113767543322853 iteration:250, loss:112.64161574579843, w:1.4786246001179788, b:0.04413730456102194 iteration:300, loss:112.63981686067933, w:1.4785656655668482, b:0.047135801876617194 iteration:350, loss:112.63801933280529, w:1.478506753252748, b:0.05013316780706675 iteration:400, loss:112.63622316115239, w:1.4784478631672875, b:0.053129402779261856 iteration:450, loss:112.63442834469728, w:1.4783889953020797, b:0.056124507219932715 iteration:500, loss:112.63263488241765, w:1.4783301496487404, b:0.05911848155564851 iteration:550, loss:112.63084277329152, w:1.478271326198889, b:0.06211132621281753 iteration:600, loss:112.62905201629816, w:1.4782125249441473, b:0.06510304161768705 iteration:650, loss:112.62726261041726, w:1.4781537458761407, b:0.06809362819634364 iteration:700, loss:112.62547455462945, w:1.478094988986498, b:0.07108308637471301 iteration:750, loss:112.62368784791623, w:1.4780362542668513, b:0.07407141657856024 iteration:800, loss:112.62190248925954, w:1.4779775417088346, b:0.07705861923348975 iteration:850, loss:112.62011847764245, w:1.4779188513040864, b:0.08004469476494522 iteration:900, loss:112.61833581204849, w:1.477860183044248, b:0.0830296435982101 iteration:950, loss:112.61655449146214, w:1.4778015369209634, b:0.08601346615840703 running... 
iteration:0, loss:1484.5865573886724, w:0.7370702973620529, b:0.014547010110780006 iteration:50, loss:112.64882487923404, w:1.478860560860723, b:0.03213199290371721 iteration:100, loss:112.64702055487857, w:1.4788015372775043, b:0.0351350200320317 iteration:150, loss:112.64521759187215, w:1.4787425359649087, b:0.03813691406602341 iteration:200, loss:112.64341598918755, w:1.4786835569145336, b:0.04113767543322853 iteration:250, loss:112.64161574579843, w:1.4786246001179788, b:0.04413730456102194 iteration:300, loss:112.63981686067933, w:1.4785656655668482, b:0.047135801876617194 iteration:350, loss:112.63801933280529, w:1.478506753252748, b:0.05013316780706675 iteration:400, loss:112.63622316115239, w:1.4784478631672875, b:0.053129402779261856 iteration:450, loss:112.63442834469728, w:1.4783889953020797, b:0.056124507219932715 iteration:500, loss:112.63263488241765, w:1.4783301496487404, b:0.05911848155564851 iteration:550, loss:112.63084277329152, w:1.478271326198889, b:0.06211132621281753 iteration:600, loss:112.62905201629816, w:1.4782125249441473, b:0.06510304161768705 iteration:650, loss:112.62726261041726, w:1.4781537458761407, b:0.06809362819634364 iteration:700, loss:112.62547455462945, w:1.478094988986498, b:0.07108308637471301 iteration:750, loss:112.62368784791623, w:1.4780362542668513, b:0.07407141657856024 iteration:800, loss:112.62190248925954, w:1.4779775417088346, b:0.07705861923348975 iteration:850, loss:112.62011847764245, w:1.4779188513040864, b:0.08004469476494522 iteration:900, loss:112.61833581204849, w:1.477860183044248, b:0.0830296435982101 iteration:950, loss:112.61655449146214, w:1.4778015369209634, b:0.08601346615840703 After1000iterations,b = 0.08893651996682016, w = 1.4777440851889796, error = 112.61481010123588