# TensorFlow 自己手动实现的线性回归 (linear regression implemented by hand with TensorFlow)
#!/usr/bin/python
# -*- coding:utf-8 -*-
import tensorflow as tf
import os
# Command-line flags. Each DEFINE_* call takes (flag name, default value,
# help text); the help strings are shown to users and stay in Chinese.
tf.app.flags.DEFINE_integer("max_iter", 100, "迭代次数")
tf.app.flags.DEFINE_string("model_dir", "./tmp/ckpt/model", "模型路径")
tf.app.flags.DEFINE_string("summary_dir", "./tmp/test/", "graph路径")
# Not the model itself: the path to the "checkpoint" state file that
# tf.train.Saver.save writes next to the model files. The training routine
# only checks that this file exists before restoring from model_dir, so its
# help text must not duplicate model_dir's.
tf.app.flags.DEFINE_string(
    "checkpoint_dir",
    "./tmp/ckpt/checkpoint",
    "checkpoint 文件路径(存在时从 model_dir 恢复模型)",
)
FLAGS = tf.app.flags.FLAGS
def mylineregression(save_model=False):
    """Fit y = 0.7 * x + 0.8 with a hand-built TF1 graph and gradient descent.

    Every session run that touches ``x`` draws a fresh batch of 100 samples
    from N(0, 1); the target is ``y = 0.7 * x + 0.8``. A 1x1 weight and a
    scalar bias are trained with plain gradient descent (learning rate 0.1)
    on the mean squared error for ``FLAGS.max_iter`` steps.

    Loss / weight / bias summaries and the graph are written to
    ``FLAGS.summary_dir``. If the file ``FLAGS.checkpoint_dir`` exists, the
    variables are restored from ``FLAGS.model_dir`` before training.

    Args:
        save_model: when True, save the trained variables to
            ``FLAGS.model_dir`` after the last step. Defaults to False, which
            matches the previous behavior (the save call was commented out).

    Returns:
        None.
    """
    with tf.variable_scope("data"):
        # Random op: re-sampled on each sess.run, so every step sees new data.
        x = tf.random_normal([100, 1], 0.0, 1.0)
        y = tf.multiply(x, [[0.7]]) + 0.8

    with tf.variable_scope("model"):
        # Variables keep their default (unnamed) names so checkpoints written
        # by earlier runs of this script still restore.
        weight = tf.Variable(tf.random_normal([1, 1], 0.0, 1.0))
        bias = tf.Variable(0.0)
        y_predict = tf.multiply(x, weight) + bias

    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.square(y - y_predict))

    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    init_value = tf.global_variables_initializer()
    saver = tf.train.Saver()

    tf.summary.scalar("losses", loss)
    tf.summary.histogram("weight", weight)
    tf.summary.histogram("bias", bias)
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(init_value)
        # One writer for both the graph and the per-step summaries. Creating a
        # second writer after training (as before) only duplicated the graph
        # event and leaked an unclosed writer.
        filwriter = tf.summary.FileWriter(FLAGS.summary_dir, graph=sess.graph)
        try:
            # Restoring overwrites the freshly initialized variable values.
            if os.path.exists(FLAGS.checkpoint_dir):
                saver.restore(sess, FLAGS.model_dir)

            for i in range(FLAGS.max_iter):
                # Fetch both parameters in one run; .item() turns the 1x1
                # array / 0-d value into a Python float for %f formatting.
                weight_val, bias_val = sess.run([weight, bias])
                print("第%d次训练参数weight:%f,bias:%f"
                      % (i, weight_val.item(), bias_val.item()))
                summary = sess.run(merged)
                filwriter.add_summary(summary, i)
                sess.run(train_op)

            if save_model:
                saver.save(sess, FLAGS.model_dir)
        finally:
            # Flush pending events to disk even if restore/training fails.
            filwriter.close()
if __name__ == "__main__":
    # Runs only when executed as a script, never on import.
    greeting = "hello"
    print(greeting)
    mylineregression()