tensorflow实战——单变量线性回归

代码示例

#  _*_ encoding:utf-8  _*_
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

trX = np.linspace(-1, 1, 101)   #在1和1之间创建一个101点的线性空间
# 注意 在元组 trX.shape 前面加了个 *,表示可变参数,可传入 0个或多个参数,很常用
trY = 2 * trX + np.random.randn(*trX.shape) * 0.4 + 0.2 # create a y value which is approximately linear but with some random noise
plt.figure()
plt.scatter(trX,trY)  #散点图
plt.plot(trX, .2 + 2 * trX)  #画直线
# 首先创建一个变量来保存x 和y的数值
X = tf.placeholder("float", name="X") # create symbolic variables
Y = tf.placeholder("float", name = "Y")
# 首先,我们为模型声明一个 name_scope,于是便可以将这个scope的变量和操作视为一个同质实体
# 在这个 scope 中,我们先定义一个线性方程,用 x 乘以权重(斜率)加上偏差
# 然后定义一个存放权重(斜率)和偏差的共享变量。这些变量在迭代计算过程中不断的发生比变化,
# 最后将定义的 model 的返回值赋给 y_model
with tf.name_scope("Model"):
    def model(X, w, b):
        return tf.multiply(X, w) + b # We just define the line as X*w + b0
    w = tf.Variable(-1.0, name="b0") # create a shared variable
    b = tf.Variable(-2.0, name="b1") # create a shared variable
    y_model = model(X, w, b)
# 在损失函数(cost function)中,同样先创建一个(scope)来包含所有的操作
# 同样 Y 在前面也已经创建符号变量,使用之前创建的 y_model 来计算 y轴值,使用均方误差(sqr error)
with tf.name_scope("CostFunction"):
    cost = (tf.pow(Y-y_model, 2)) # use sqr error for cost function
# 定义一个 Optimizer,这里采用的优化方法是梯度下降,步长设为0.05 为经验值
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost)
# 创建一个会话,并将初始化的变量保存起来,便于TensorBoard中查看本例中我们将每次迭代后的最后一个误差结果作为一个标量保存起来。
# 同理,我们也需要将 TensorFlow 生成的图结构奥存起来用于之后查看
sess = tf.Session()
init = tf.initialize_all_variables()
tf.train.write_graph(sess.graph, '/home/ubuntu/linear','graph.pbtxt')
cost_op = tf.summary.scalar("loss", cost)
merged = tf.summary.merge_all()
sess.run(init)
writer = tf.summary.FileWriter('/home/ubuntu/linear', sess.graph)
# 在模型训练阶段,设置迭代 100 次,每次我们通过将样本输入模型,进行梯度下降。
# 迭代之后,绘制出模型曲线,并将误差存入 summary
for i in range(100):
    for (x, y) in zip(trX, trY):
        sess.run(train_op, feed_dict={X: x, Y: y})
        summary_str = sess.run(cost_op, feed_dict={X: x, Y: y})
        writer.add_summary(summary_str, i)
    b0temp=b.eval(session=sess)
    b1temp=w.eval(session=sess)
    plt.plot (trX, b0temp + b1temp * trX )


print(sess.run(w)) # Should be around 2
print(sess.run(b)) #Should be around 0.2


plt.scatter(trX,trY)
plt.plot (trX, sess.run(b) + trX * sess.run(w))
plt.show()

 输出结果:

也可以在tensorboard中查看数据结果。

tensorboard的启用,需要指定日志目录,执行以下命令:

tensorboard --logdir=.(路径)

tensorboard会加载日志目录中的事件和图形文件,并监听6006口。你可以在浏览器中输入“localhost:6006”,然后就能在浏览器中看到类似tensorboard的仪表盘。

猜你喜欢

转载自blog.csdn.net/zqzq19950725/article/details/88287233