15_模型加载、测试集

在这里插入图片描述


博文配套视频课程:24小时实现从零到AI人工智能


模型加载注意事项

  1. 模型加载前提是之前模型已经有保存
  2. 加载之后的模型应该是训练好的模型,在进行测试时是不需要在进行梯度下降降低误差的
  3. 训练的数据集必须要与测试的数据集特征值相同
  4. 测试集设置的name参数必须与模型保存时的name数量和通途上匹配

模型加载代码

# 完成图形图像的加载 (分辨率和尺寸相同)
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# 通过函数下载或者加载本地的图片资源,建议设置one_hot=True
mnist = input_data.read_data_sets("../data/input_data",one_hot=True)
# 训练的时候不会一次性加载55000,因此需要采用占位符
X = tf.placeholder(tf.float32,[None,784])
y_true = tf.placeholder(tf.float32,[None,10])
# 多少个连接就会有多少个权重weight   [None,784] dot [784,10]   ===> [None,10]
# 变量: 深度学习过程中可变的量
weight = tf.Variable(tf.random.normal([784,10]),name="weight")
bias = tf.Variable(tf.random.normal([10]),name="bias")
y_predict = tf.matmul(X,weight) + bias
# 被加载的模型无序在进行梯度下降
saver = tf.train.Saver()
# 通过会话连接graph图
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 判断模型是否存在
    import os
    if os.path.exists("../data/sess"):
        print('模型加载成功..............')
        # 用加载的sess,替换现有的sess
        saver.restore(sess,"../data/sess/point")
        for i in range(30):
            # 每一次获取指定数量的样本和样本的目标值
            x_test,y_test = mnist.test.next_batch(1)
            d = {X:x_test,y_true:y_test}
            # print(f'第{i+1}次训练的正确率为:{sess.run(accuracy,feed_dict=d)}')
            print(f'第{i + 1}次,真实值为:{tf.argmax(y_test,1).eval()},预测值为:{sess.run(tf.argmax(y_predict,1),feed_dict=d)}')
    else:
        print('模型不存在,或者路径有误!')

在这里插入图片描述

发布了128 篇原创文章 · 获赞 17 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/lsqzedu/article/details/102557918
今日推荐