人工智能实战第二次作业——反向传播代码实现
1.作业要求
1.1 使用Python实现双变量的线性反向传播
1.2通过梯度下降求解参数w和b的最终值
1.3具体要求:线性反向传播
2.python代码实现
2.1 使用sympy库自动求导
from sympy import *
def derivatede_db(tw, tb):
    """Evaluate dz/db at the point w=tw, b=tb via the chain rule.

    The model is z = x * y with x = 2*w + 3*b and y = 2*b + 1.  The
    derivative is assembled symbolically as
    dx/db * dz/dx + dy/db * dz/dy, and the symbols are then replaced by
    the values obtained from tw and tb.

    Args:
        tw: value substituted for w.
        tb: value substituted for b.

    Returns:
        The SymPy expression left after all substitutions.
    """
    w_sym, b_sym = symbols("w"), symbols("b")
    x_sym, y_sym = symbols("x"), symbols("y")

    # Intermediate quantities expressed in terms of the parameters.
    x_expr = 2 * w_sym + 3 * b_sym
    y_expr = 2 * b_sym + 1
    # Output expressed in terms of the intermediate symbols x and y.
    z_expr = x_sym * y_sym

    # Chain rule: sum of the contributions through x and through y.
    through_x = diff(x_expr, b_sym) * diff(z_expr, x_sym)
    through_y = diff(y_expr, b_sym) * diff(z_expr, y_sym)
    grad = through_x + through_y

    # Values of x and y at the requested point.
    x_val = x_expr.subs('w', tw).subs('b', tb)
    y_val = y_expr.subs('w', tw).subs('b', tb)

    # Apply the substitutions one after another, in w, b, x, y order.
    for name, value in (('w', tw), ('b', tb), ('x', x_val), ('y', y_val)):
        grad = grad.subs(name, value)
    return grad
def derivatede_dw(tw, tb):
    """Evaluate dz/dw at the point w=tw, b=tb via the chain rule.

    The model is z = x * y with x = 2*w + 3*b and y = 2*b + 1.  The
    derivative is assembled symbolically as
    dx/dw * dz/dx + dy/dw * dz/dy, and the symbols are then replaced by
    the values obtained from tw and tb.

    Args:
        tw: value substituted for w.
        tb: value substituted for b.

    Returns:
        The SymPy expression left after all substitutions.
    """
    w_sym, b_sym = symbols("w"), symbols("b")
    x_sym, y_sym = symbols("x"), symbols("y")

    # Intermediate quantities expressed in terms of the parameters.
    x_expr = 2 * w_sym + 3 * b_sym
    y_expr = 2 * b_sym + 1
    # Output expressed in terms of the intermediate symbols x and y.
    z_expr = x_sym * y_sym

    # Chain rule: sum of the contributions through x and through y.
    through_x = diff(x_expr, w_sym) * diff(z_expr, x_sym)
    through_y = diff(y_expr, w_sym) * diff(z_expr, y_sym)
    grad = through_x + through_y

    # Values of x and y at the requested point.
    x_val = x_expr.subs('w', tw).subs('b', tb)
    y_val = y_expr.subs('w', tw).subs('b', tb)

    # Apply the substitutions one after another, in w, b, x, y order.
    for name, value in (('w', tw), ('b', tb), ('x', x_val), ('y', y_val)):
        grad = grad.subs(name, value)
    return grad
2.2非迭代反向传播计算
# Non-iterative backpropagation: the gradients dz/db and dz/dw are evaluated
# once at the starting point and reused unchanged for every update step.
w = 3.00000
b = 4.00000
z_true = 150.00000
dz = 1.00  # placeholder error so the loop body runs at least once
db = 0.00
dw = 0.00
# Evaluate the gradients at the starting point (w, b) instead of repeating the
# literals, and convert them to plain floats so that w, b and the deltas do
# not silently turn into SymPy numbers after the first update.
dz_db = float(derivatede_db(w, b))
dz_dw = float(derivatede_dw(w, b))
# Compare the magnitude of the error: a negative error (z below z_true after
# an overshoot) must not be treated as convergence.
while abs(dz) >= 1e-5:
    # Forward pass.
    x = 2 * w + 3 * b
    y = 2 * b + 1
    z = x * y
    dz = z - z_true
    # Each parameter is asked to correct half of the error.
    db = (dz / 2) / dz_db
    dw = (dz / 2) / dz_dw
    print("w=%f,,b=%f,z=%f,delta_z=%f,delta_b=%f,delta_w=%f"%(w,b,z,dz,db,dw))
    w = w - dw
    b = b - db
print("done!")
print("final_b=%f"%b)
print("final_w=%f"%w)
print("final_z=%f"%z)
/*
** 运算结果
*/
w=3.000000,,b=4.000000,z=162.000000,delta_z=12.000000,delta_b=0.095238,delta_w=0.333333
w=2.666667,,b=3.904762,z=150.181406,delta_z=0.181406,delta_b=0.001440,delta_w=0.005039
w=2.661628,,b=3.903322,z=150.005526,delta_z=0.005526,delta_b=0.000044,delta_w=0.000154
w=2.661474,,b=3.903278,z=150.000170,delta_z=0.000170,delta_b=0.000001,delta_w=0.000005
w=2.661469,,b=3.903277,z=150.000005,delta_z=0.000005,delta_b=0.000000,delta_w=0.000000
done!
final_b=3.903277
final_w=2.661469
final_z=150.000005
2.3迭代式反向传播
# Iterative backpropagation: the gradients dz/db and dz/dw are re-evaluated
# at the current (w, b) on every update step.
w = 3.00000
b = 4.00000
z_true = 150.00000
dz = 1.00  # placeholder error so the loop body runs at least once
db = 0.00
dw = 0.00
# Compare the magnitude of the error: a negative error (z below z_true after
# an overshoot) must not be treated as convergence.
while abs(dz) >= 1e-5:
    # Forward pass.
    x = 2 * w + 3 * b
    y = 2 * b + 1
    z = x * y
    dz = z - z_true
    # Each parameter is asked to correct half of the error.  float() keeps
    # w, b and the deltas as plain Python floats rather than SymPy numbers.
    db = (dz / 2) / float(derivatede_db(w, b))
    dw = (dz / 2) / float(derivatede_dw(w, b))
    print("w=%f,,b=%f,z=%f,delta_z=%f,delta_b=%f,delta_w=%f"%(w,b,z,dz,db,dw))
    w = w - dw
    b = b - db
print("done!")
print("final_b=%f"%b)
print("final_w=%f"%w)
print("final_z=%f"%z)
/*
***运算结果
*/
w=3.000000,,b=4.000000,z=162.000000,delta_z=12.000000,delta_b=0.095238,delta_w=0.333333
w=2.666667,,b=3.904762,z=150.181406,delta_z=0.181406,delta_b=0.001499,delta_w=0.005148
w=2.661519,,b=3.903263,z=150.000044,delta_z=0.000044,delta_b=0.000000,delta_w=0.000001
w=2.661517,,b=3.903263,z=150.000000,delta_z=0.000000,delta_b=0.000000,delta_w=0.000000
done!
final_b=3.903263
final_w=2.661517
final_z=150.000000
3.结果比较
3.1 迭代式反向传播的收敛速度更快：达到终止条件只需 4 次迭代，而非迭代式需要 5 次
3.2 迭代式反向传播的最终精度更高：final_z=150.000000，而非迭代式为 final_z=150.000005