tensorflow模型保存与可视化

本例以数据的二分类为例,实现了模型的保存、加载、以及tensorboard的可视化。

1、实现功能

对如下数据进行二分类,[[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]],如图所示。
这里写图片描述
数据以Y=X为分界线,上部分是1类,下部分是0类。

2、具体代码

代码分为两部分,一个是lt_save.py,主要实现了模型的训练与保存,一个是lt_load.py主要实现模型的加载。

2.1 lt_save.py

#coding:utf-8
from __future__ import division

import tensorflow as tf 
import numpy as np 
import os

X = np.array([[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]])
Y = np.array([[0,1],[0,1],[0,1],[0,1],[1,0],[1,0],[1,0],[1,0]])


#-----------------------------前向过程-----------------------
x_input = tf.placeholder(tf.float32,shape = (None,2),name = "x_input")
y_input = tf.placeholder(tf.int16,shape = (None,2),name = "y_input")

with tf.variable_scope("layer_1"):
    w1 = tf.Variable(tf.random_normal([2,3],stddev=0.5))
    b1 = tf.Variable(tf.random_normal([3,]))
    a1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(x_input,w1),b1))

with tf.variable_scope("layer_2"):
    w2 = tf.Variable(tf.random_normal([3,4],stddev=1))
    b2 = tf.Variable(tf.random_normal([4,]))
    a2 = tf.nn.relu(tf.nn.bias_add(tf.matmul(a1,w2),b2))

with tf.variable_scope("output"):
    w3 = tf.Variable(tf.random_normal([4,2],stddev=2))
    b3 = tf.Variable(tf.random_normal([2,]))
    y =  tf.nn.bias_add(tf.matmul(a2,w3),b3)
    y_prediction = tf.arg_max(y,1,name = "prediction")




with tf.variable_scope("cost"):
    cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(y,tf.arg_max(y_input,1)))
    tf.scalar_summary("loss",cost)  #写入日志文件

with tf.variable_scope("accuracy"):
    corrcet_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_input,1))
    acc = tf.reduce_mean(tf.cast(corrcet_prediction,tf.float32))    #这里要将corrcet_prediction转换为浮点型
    tf.scalar_summary("acc",acc)   #写入日志文件

global_step = tf.Variable(0,trainable=False)

train_step = tf.train.AdamOptimizer(0.005).minimize(cost, global_step=global_step) #定义最小化cost函数操作
#------------------------------------------------------------

init = tf.initialize_all_variables()  #初始化变量操作
merged = tf.merge_all_summaries()  #整理所有的日志文件

STEPS = 3001
SAVE_PATH = "./model/"
MODEL_NAME = "lt_model.ckpt"
SUMMARY_PATH = "./summary"
batch_size = 4
data_size = len(X)
print "\ndata_size",data_size
saver = tf.train.Saver()
with tf.Session() as sess:


    summary_writer = tf.train.SummaryWriter(SUMMARY_PATH,tf.get_default_graph())

    sess.run(init)

    for i in range(STEPS):
        start = (i*batch_size)%data_size  
        # print "\nstart",start      
        end = min(start+batch_size,data_size)    #这里start 与end 必须要满足data_size能被batch_size整除
                                                 #这里也可以用yield进行迭代供给数据
        # print "\nend",end

        summary, _ = sess.run([merged,train_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]})  #得到运行时的日志
        summary_writer.add_summary(summary,i)  #将所有日志写入文件
        if i%100 == 0:
            loss,accuracy,currect_step =  sess.run([cost,acc,global_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]})
            print "step=",currect_step,"loss=",loss,"acc=",accuracy
        if i%3000 == 0:
            saver.save(sess,os.path.join(SAVE_PATH,MODEL_NAME),global_step=currect_step)  #保存模型
            print "model has been saved"


summary_writer.close()   #关闭日志文件
print "all done"
#查看可视化结果:
#tensorboard --logdir=./summary

运行lt_save.py,保存了模型与日志文件。

2.2 查看日志文件

若想查看日志内容,如loss的变化情况,输入:

tensorboard --logdir=./summary

然后会出现一个网址,点击,则可以进入tensorboard,查看保存日志,得到 loss, acc 的数据。如下图所示。
这里写图片描述
这里写图片描述
这里写图片描述

2.3 lt_load.py

这个代码实现了加载模型,并且对一个数据进行分类。

#coding:utf-8
import tensorflow as tf
import numpy as np 

X = np.array([[4.,4.5]])  #输入数据

graph = tf.Graph()
with graph.as_default():
    sess = tf.Session()
    with sess.as_default():
        saver = tf.train.import_meta_graph("./model/lt_model.ckpt-3001.meta")#加载图,这里保存了整个图的结构
        saver.restore(sess,"./model/lt_model.ckpt-3001")#加载模型,这里保存了每个变量的值

        x_input = graph.get_operation_by_name("x_input").outputs[0]#加载placeholder,这里计算节点为"x_input",其本身没有:0

        y_prediction = graph.get_operation_by_name("output/prediction").outputs[0]#从节点"output/prediction" 加载张量 y_prediction。注意.outputs[0]是指
                                                                                  #节点的第一个输出,即y_prediction
        pred = sess.run(y_prediction,feed_dict={x_input:X})#进行预测
        print "prediction is ",pred

输出结果为:prediction is [1]

3 不足之处

保存模型时没有设置checkpoints,应该是训练一定步数保存一次,也没有实现自动加载最新的模型,接下来应当实现。

猜你喜欢

转载自blog.csdn.net/liushui94/article/details/78309167