Artificial Intelligence in Practice, Assignment 2 (15041025, 陶恺)

Artificial Intelligence in Practice, Assignment 2: Backpropagation Code Implementation

1. Assignment Requirements

1.1 Implement two-variable linear backpropagation in Python
1.2 Solve for the final values of the parameters w and b by gradient descent
1.3 Specific task: linear backpropagation (the forward computation and chain-rule derivatives are written out below)
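The forward computation used throughout this writeup (it appears in the code below) is x = 2w + 3b, y = 2b + 1, z = xy, and the goal is to adjust w and b until z reaches the target value 150. The chain rule gives the two partial derivatives that both solvers rely on:

\frac{\partial z}{\partial w} = \frac{\partial z}{\partial x}\frac{\partial x}{\partial w} = 2y = 2(2b+1)

\frac{\partial z}{\partial b} = \frac{\partial z}{\partial x}\frac{\partial x}{\partial b} + \frac{\partial z}{\partial y}\frac{\partial y}{\partial b} = 3y + 2x = 3(2b+1) + 2(2w+3b)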

2. Python Implementation

2.1 Automatic differentiation with the sympy library

from sympy import *

# dz/db at the point (w=tw, b=tb), built from the chain rule:
# dz/db = (dx/db)*(dz/dx) + (dy/db)*(dz/dy)
def derivatede_db(tw, tb):
    w = symbols("w")
    b = symbols("b")
    x_symbol = symbols("x")
    y_symbol = symbols("y")
    x = 2*w + 3*b            # x(w, b)
    y = 2*b + 1              # y(b)
    z = x_symbol*y_symbol    # z(x, y)
    dz = diff(x, b)*diff(z, x_symbol) + diff(y, b)*diff(z, y_symbol)
    # substitute the numeric point, including the values of x and y at (tw, tb)
    dz = dz.subs('w', tw)
    dz = dz.subs('b', tb)
    dz = dz.subs('x', x.subs('w', tw).subs('b', tb))
    dz = dz.subs('y', y.subs('w', tw).subs('b', tb))
    return dz

# dz/dw at the point (w=tw, b=tb), built from the chain rule:
# dz/dw = (dx/dw)*(dz/dx) + (dy/dw)*(dz/dy)
def derivatede_dw(tw, tb):
    w = symbols("w")
    b = symbols("b")
    x_symbol = symbols("x")
    y_symbol = symbols("y")
    x = 2*w + 3*b
    y = 2*b + 1
    z = x_symbol*y_symbol
    dz = diff(x, w)*diff(z, x_symbol) + diff(y, w)*diff(z, y_symbol)
    dz = dz.subs('w', tw)
    dz = dz.subs('b', tb)
    dz = dz.subs('x', x.subs('w', tw).subs('b', tb))
    dz = dz.subs('y', y.subs('w', tw).subs('b', tb))
    return dz
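
As a quick sanity check (not part of the original post), the two helpers can be evaluated at the starting point w=3, b=4; by the chain-rule expressions above they should return 63 and 18:

# Sanity check: partial derivatives at the starting point (w=3, b=4).
# dz/db = 3*(2b+1) + 2*(2w+3b) = 27 + 36 = 63, dz/dw = 2*(2b+1) = 18.
print(derivatede_db(3, 4))   # prints 63
print(derivatede_dw(3, 4))   # prints 18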

2.2 Non-iterative backpropagation (the derivatives are evaluated only once, at the starting point)

# Non-iterative backpropagation: the partial derivatives are computed once,
# at the starting point (w=3, b=4), and reused in every update step.
w = 3.00000
b = 4.00000
z_true = 150.00000   # target value of z
dz = 1.00
db = 0.00
dw = 0.00
dz_db = derivatede_db(3, 4)   # dz/db at the starting point (= 63)
dz_dw = derivatede_dw(3, 4)   # dz/dw at the starting point (= 18)
while dz >= 1e-5:
    # forward pass
    x = 2*w + 3*b
    y = 2*b + 1
    z = x*y
    dz = z - z_true
    # split the error evenly between b and w and convert it into parameter updates
    db = (dz/2)/dz_db
    dw = (dz/2)/dz_dw
    print("w=%f,b=%f,z=%f,delta_z=%f,delta_b=%f,delta_w=%f" % (w, b, z, dz, db, dw))
    w = w - dw
    b = b - db
print("done!")
print("final_b=%f" % b)
print("final_w=%f" % w)
print("final_z=%f" % z)
Output:

w=3.000000,b=4.000000,z=162.000000,delta_z=12.000000,delta_b=0.095238,delta_w=0.333333
w=2.666667,b=3.904762,z=150.181406,delta_z=0.181406,delta_b=0.001440,delta_w=0.005039
w=2.661628,b=3.903322,z=150.005526,delta_z=0.005526,delta_b=0.000044,delta_w=0.000154
w=2.661474,b=3.903278,z=150.000170,delta_z=0.000170,delta_b=0.000001,delta_w=0.000005
w=2.661469,b=3.903277,z=150.000005,delta_z=0.000005,delta_b=0.000000,delta_w=0.000000
done!
final_b=3.903277
final_w=2.661469
final_z=150.000005
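
To see where the first output line comes from: at w=3, b=4 the forward pass gives x=18, y=9, z=162, so the error is \Delta z = 162 - 150 = 12. Half of the error is assigned to each parameter and divided by the corresponding derivative at the starting point:

\Delta b = \frac{\Delta z/2}{\partial z/\partial b} = \frac{6}{63} \approx 0.095238, \qquad \Delta w = \frac{\Delta z/2}{\partial z/\partial w} = \frac{6}{18} \approx 0.333333

which matches delta_b and delta_w in the first line above.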

2.3 Iterative backpropagation (the derivatives are re-evaluated at every step)

# Iterative backpropagation: the partial derivatives are re-evaluated
# at the current (w, b) in every update step.
w = 3.00000
b = 4.00000
z_true = 150.00000   # target value of z
dz = 1.00
db = 0.00
dw = 0.00
while dz >= 1e-5:
    # forward pass
    x = 2*w + 3*b
    y = 2*b + 1
    z = x*y
    dz = z - z_true
    # split the error evenly and divide by the derivatives at the current point
    db = (dz/2)/derivatede_db(w, b)
    dw = (dz/2)/derivatede_dw(w, b)
    print("w=%f,b=%f,z=%f,delta_z=%f,delta_b=%f,delta_w=%f" % (w, b, z, dz, db, dw))
    w = w - dw
    b = b - db
print("done!")
print("final_b=%f" % b)
print("final_w=%f" % w)
print("final_z=%f" % z)
Output:

w=3.000000,b=4.000000,z=162.000000,delta_z=12.000000,delta_b=0.095238,delta_w=0.333333
w=2.666667,b=3.904762,z=150.181406,delta_z=0.181406,delta_b=0.001499,delta_w=0.005148
w=2.661519,b=3.903263,z=150.000044,delta_z=0.000044,delta_b=0.000000,delta_w=0.000001
w=2.661517,b=3.903263,z=150.000000,delta_z=0.000000,delta_b=0.000000,delta_w=0.000000
done!
final_b=3.903263
final_w=2.661517
final_z=150.000000
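
As an extra check (not in the original post), the printed final parameters of both runs can be plugged back into the forward formula; each pair reproduces the target z = 150 up to the rounding of the six-decimal output:

# Verify the printed final parameters of both runs against the target z = 150.
# The tiny deviations come from the printed w and b being rounded to 6 decimals.
for name, w, b in [("non-iterative", 2.661469, 3.903277),
                   ("iterative",     2.661517, 3.903263)]:
    z = (2*w + 3*b) * (2*b + 1)
    print("%s: z = %.6f" % (name, z))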

3. Comparison of Results

3.1 The iterative version converges faster: it satisfies the stopping criterion after 4 iterations, versus 5 for the non-iterative version.
3.2 The iterative version is more accurate: its final z is 150.000000, versus 150.000005 for the non-iterative version.

Reposted from www.cnblogs.com/tk7362/p/10535246.html