tensorflow-保存与读取使用模型

1、MNIST是深度学习的经典入门demo,他是由6万张训练图片和1万张测试图片构成的,每张图片都是2828大小(如下图),而且都是黑白色构成(这里的黑色是一个0-1的浮点数,黑色越深表示数值越靠近1),这些图片是采集的不同的人手写从0到9的数字。
下面先训练识别数字模型
再保存模型
最后,读取保存的模型,对数字图片进行识别。

2、保存模型

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb  3 20:28:26 2019

@author: myhaspl
"""
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

import tensorflow as tf
import os

x=tf.placeholder(tf.float32,[None,784])

w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

y=tf.nn.softmax(tf.matmul(x,w)+b)
y_=tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver()
for i in range(1000):
    sampleX,sampleY=mnist.train.next_batch(100)
    sess.run(train_step,feed_dict={x:sampleX,y_:sampleY})

print("训练完成")
print("保存生成模型...")
model_dir="mnist_model"
model_name="ml1"
if not os.path.exists(model_dir):
    os.mkdir(model_dir)

saver.save(sess,os.path.join(model_dir,model_name))
print("保存生成模型成功")  
训练完成
保存生成模型...
保存生成模型成功
[root@VM03centos learn]# ls mnist_model 
checkpoint  ml1.data-00000-of-00001  ml1.index  ml1.meta
[root@VM03centos learn]# ls MNISTdata 
t10k-images-idx3-ubyte.gz  t10k-labels-idx1-ubyte.gz  train-images-idx3-ubyte.gz  train-labels-idx1-ubyte.gz
[root@VM03centos learn]# 

读取数字识别模型,对某个数字图像进行识别

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb  3 20:28:26 2019

@author: myhaspl
"""
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

import tensorflow as tf

x=tf.placeholder(tf.float32,[None,784])

w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

y=tf.nn.softmax(tf.matmul(x,w)+b)
y_=tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver()

print("读取模型...")
saver.restore(sess,"mnist_model/ml1")
print("读取模型完成")
print("根据模型进行计算...")
img=mnist.test.images[5]
result=sess.run(y,feed_dict={x:img.reshape(1,784)})
print("预测输出结果:{}".format(result))
print("预测结果:{}".format(result.argmax()))
print("实际结果:{}".format(mnist.test.labels[5].argmax()))
读取模型...
INFO:tensorflow:Restoring parameters from mnist_model/ml1
读取模型完成
根据模型进行计算...
预测输出结果:[[1.8999807e-06 9.8351490e-01 3.0815993e-03 4.3848301e-03 4.1427880e-05
  1.6864968e-04 7.6594086e-05 4.5587993e-03 3.2991443e-03 8.7222963e-04]]
预测结果:1
实际结果:1

猜你喜欢

转载自blog.51cto.com/13959448/2348859