TensorFlow linear problem: linear regression with gradient descent

The program below fits a line y = w * x + b to the points in data.csv by gradient descent; despite the title, it only needs NumPy and Matplotlib, not TensorFlow itself.

import numpy as np
import matplotlib.pyplot as plt 

# Default parameters for plots
plt.rcParams['font.size'] = 16
plt.rcParams['font.family'] = ['STKaiti']  # a Chinese font (macOS); only needed for Chinese plot labels
plt.rcParams['axes.unicode_minus'] = False

def compute_error(b, w, data):
    # Mean squared error of the line y = w * x + b over all points
    totalError = 0
    for i in range(len(data)):
        x = data[i, 0]
        y = data[i, 1]
        totalError += (w * x + b - y) ** 2
    return totalError / len(data)
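
# A vectorized equivalent (not in the original post), sketched for reference:
# operating on whole columns avoids the Python-level loop and returns the
# same value.
def compute_error_vectorized(b, w, data):
    x, y = data[:, 0], data[:, 1]
    return np.mean((w * x + b - y) ** 2)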


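# The original post calls compute_gradient but never defines it; this is a
# reconstruction of the standard MSE gradient-descent step, consistent with
# compute_error above and with the printed output below (an assumption, not
# the author's verbatim code).
def compute_gradient(b, w, data, learning_rate):
    b_gradient = 0
    w_gradient = 0
    n = float(len(data))
    for i in range(len(data)):
        x = data[i, 0]
        y = data[i, 1]
        # partial derivatives of the MSE with respect to b and w
        b_gradient += (2 / n) * (w * x + b - y)
        w_gradient += (2 / n) * (w * x + b - y) * x
    # step against the gradient
    new_b = b - learning_rate * b_gradient
    new_w = w - learning_rate * w_gradient
    return new_b, new_w

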
def gradient_descent(b, w, learning_rate, iterations, data):
    new_b = b
    new_w = w
    losses = []
    for i in range(iterations):
        loss = compute_error(new_b, new_w, data)
        losses.append(loss)
        new_b, new_w = compute_gradient(new_b, new_w, np.array(data), learning_rate)

        if i % 50 == 0:  # print the loss and the current w, b every 50 iterations
            print(f"iteration:{i}, loss:{compute_error(new_b, new_w, data)}, w:{new_w}, b:{new_b}")

    return [new_b, new_w], losses


def run():
    data = np.genfromtxt("data.csv", delimiter=',')  # two columns: x, y
    initial_b = 0
    initial_w = 0
    learning_rate = 0.0001
    iterations = 1000
    print('Starting gradient descent at b = {0}, w = {1}, error = {2}'.format(
        initial_b, initial_w, compute_error(initial_b, initial_w, data)))
    # Train for `iterations` steps; returns the optimized w*, b* and the loss history
    [b, w], losses = gradient_descent(initial_b, initial_w, learning_rate, iterations, data)
    
    print('After {0} iterations, b = {1}, w = {2}, error = {3}'.format(
        iterations, b, w, compute_error(b, w, data)))
    x = list(range(iterations))
    plt.plot(x, losses, color='C1', label='MSE')
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    run()

Output:

Starting gradient descent at b = 0, w = 0, error = 5565.107834490552
iteration:0, loss:1484.5865573886724, w:0.7370702973620529, b:0.014547010110780006
iteration:50, loss:112.64882487923404, w:1.478860560860723, b:0.03213199290371721
iteration:100, loss:112.64702055487857, w:1.4788015372775043, b:0.0351350200320317
iteration:150, loss:112.64521759187215, w:1.4787425359649087, b:0.03813691406602341
iteration:200, loss:112.64341598918755, w:1.4786835569145336, b:0.04113767543322853
iteration:250, loss:112.64161574579843, w:1.4786246001179788, b:0.04413730456102194
iteration:300, loss:112.63981686067933, w:1.4785656655668482, b:0.047135801876617194
iteration:350, loss:112.63801933280529, w:1.478506753252748, b:0.05013316780706675
iteration:400, loss:112.63622316115239, w:1.4784478631672875, b:0.053129402779261856
iteration:450, loss:112.63442834469728, w:1.4783889953020797, b:0.056124507219932715
iteration:500, loss:112.63263488241765, w:1.4783301496487404, b:0.05911848155564851
iteration:550, loss:112.63084277329152, w:1.478271326198889, b:0.06211132621281753
iteration:600, loss:112.62905201629816, w:1.4782125249441473, b:0.06510304161768705
iteration:650, loss:112.62726261041726, w:1.4781537458761407, b:0.06809362819634364
iteration:700, loss:112.62547455462945, w:1.478094988986498, b:0.07108308637471301
iteration:750, loss:112.62368784791623, w:1.4780362542668513, b:0.07407141657856024
iteration:800, loss:112.62190248925954, w:1.4779775417088346, b:0.07705861923348975
iteration:850, loss:112.62011847764245, w:1.4779188513040864, b:0.08004469476494522
iteration:900, loss:112.61833581204849, w:1.477860183044248, b:0.0830296435982101
iteration:950, loss:112.61655449146214, w:1.4778015369209634, b:0.08601346615840703
After 1000 iterations, b = 0.08893651996682016, w = 1.4777440851889796, error = 112.61481010123588
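
As a sanity check (not part of the original post), the closed-form least-squares line can be computed with np.polyfit and compared against the gradient-descent estimate; with this small learning rate the loop converges slowly, so the two fits need not agree closely after only 1000 iterations. The sketch assumes the same data.csv the script reads.

import numpy as np

data = np.genfromtxt("data.csv", delimiter=',')
# np.polyfit returns coefficients highest-degree first: [w, b] for degree 1
w_ls, b_ls = np.polyfit(data[:, 0], data[:, 1], deg=1)
print(f"closed-form least squares: w = {w_ls}, b = {b_ls}")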

Source: blog.csdn.net/intmain_S/article/details/120782385