[ MOOC课程学习 ] 人工智能实践:Tensorflow笔记_CH6_1 输入手写数字输出识别结果

输入手写数字输出识别结果

实现断点续训
输入真实图片,输出预测结果

  1. 实现断点续训,在 mnist_backward.py 里加入三行代码即可:

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
    
            # Resume training from a checkpoint ------------------------------
            # If MODEL_SAVE_PATH contains a valid "checkpoint" state file, load
            # the most recently saved variable values over the freshly
            # initialized ones; otherwise training starts from scratch.
            # NOTE(review): MODEL_SAVE_PATH, saver, STEPS, mnist, train_op, loss,
            # global_step, x, y_ and MODEL_NAME are defined elsewhere in
            # mnist_backward.py (not shown here).
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            # ----------------------------------------------------
    
            # `i` counts iterations of this run only (always starts at 0), while
            # `step` is the fetched global_step value; presumably global_step is
            # among the saved variables, so after a restore it continues from the
            # previous run -- confirm against the Saver definition.
            for i in range(STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                _, loss_v, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
                # Every 1000 iterations of this run: log the batch loss and write
                # a checkpoint whose filename is suffixed with global_step.
                if i % 1000 == 0:
                    print('After %d training steps, loss on training batch is %g.' % (step, loss_v))
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
    

    (1) tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
    该函数读取断点文件夹中的 checkpoint 状态文件:如果存在有效的断点状态,则返回一个 CheckpointState 对象(其 model_checkpoint_path 属性指向最新保存的模型),否则返回 None。
    参数说明:
    checkpoint_dir:表示存储断点文件的目录
    latest_filename=None:断点文件的可选名称,默认为“checkpoint”
    (2) saver.restore(sess, ckpt.model_checkpoint_path)
    该函数把 ckpt 所指断点文件中保存的变量值(如 w 和 b)恢复到当前会话中。
    参数说明:
    sess:表示当前会话,之前保存的结果将被加载入这个会话
    ckpt.model_checkpoint_path:表示最新保存的模型的路径。该路径由 get_checkpoint_state 从 checkpoint 文件中读出,因此无需手动提供模型文件名,程序会自动定位到最新保存的模型。

  2. 输入真实图片,输出预测结果:
    mnist_forward.py、mnist_backward.py 和 mnist_test.py 保持不变,只需新增一个 mnist_app.py。
    模型要求输入为黑(0)底白(255)字,但输入的图片是白底黑字,所以需要把每个像素值改为 255 减去原值,得到反色图像;随后再按阈值做二值化,以滤除噪声。

    
    # coding:utf-8
    
    import tensorflow as tf
    import numpy as np
    from PIL import Image
    import mnist_forward
    import mnist_backward
    
    def restore_model(testPicArr):
        """Rebuild the network, load the latest moving-average weights and predict.

        A brand-new ``tf.Graph`` is built on every call. Without it, each call
        would add another copy of the network's variables (auto-named
        ``Variable_4``, ``Variable_5``, ...) to the process-wide default graph,
        and the ``Saver`` would then request checkpoint keys that do not exist.

        Args:
            testPicArr: the batch fed to the ``x`` placeholder; presumably the
                (1, INPUT_NODE) float32 array produced by ``pre_pic`` -- confirm.

        Returns:
            The evaluated ``tf.argmax(y, 1)`` (one predicted class index per
            input row), or ``-1`` when no checkpoint is found under
            ``mnist_backward.MODEL_SAVE_PATH``.
        """
        graph = tf.Graph()
        with graph.as_default():
            x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
            logits = mnist_forward.forward(x, None)
            prediction = tf.argmax(logits, 1)

            # Map every model variable to its "<name>/ExponentialMovingAverage"
            # checkpoint key so the moving-average values are loaded into it.
            shadow_map = tf.train.ExponentialMovingAverage(
                mnist_backward.EMA_DECAY).variables_to_restore()
            saver = tf.train.Saver(shadow_map)

            with tf.Session() as sess:
                # The "checkpoint" index file points at the newest saved model.
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found.')
                    return -1
                saver.restore(sess, ckpt.model_checkpoint_path)
                return sess.run(prediction, feed_dict={x: testPicArr})
    
    
    
    # 输入图片预处理函数
    
    def pre_pic(picName, threshold=50):
        """Convert a white-background / black-digit picture into one model input row.

        Pipeline:
        1. Resize to 28x28 with the Lanczos ("antialias") filter.
        2. Convert to 8-bit grayscale.
        3. Invert the values, because the model expects a black (0) background
           with white (255) strokes.
        4. Binarize against ``threshold`` to suppress noise.
        5. Flatten and scale by 1/255.

        Args:
            picName: path (or file object) accepted by ``PIL.Image.open``.
            threshold: inverted pixel values below this become 0 and all others
                become 255. Defaults to 50, the value used originally; tune it
                while debugging.

        Returns:
            A float32 numpy array of shape (1, 784) whose entries are either 0
            or 255 * (1/255) (about 1.0).
        """
        # Pillow 10 removed ``Image.ANTIALIAS``; ``Image.LANCZOS`` is the same
        # filter under its long-standing name. Only fall back to ANTIALIAS on
        # very old Pillow builds that predate LANCZOS (looked up lazily so the
        # missing attribute is never touched on new versions).
        resample = getattr(Image, 'LANCZOS', None)
        if resample is None:
            resample = Image.ANTIALIAS

        # Use the image as a context manager so the underlying file is closed.
        with Image.open(picName) as img:
            im_arr = np.array(img.resize((28, 28), resample).convert('L'))

        # Vectorized invert + binarize; identical to the former per-pixel loop.
        # im_arr is uint8 (0..255), so 255 - im_arr cannot underflow.
        inverted = 255 - im_arr
        binary = np.where(inverted < threshold, 0, 255)

        nm_arr = binary.reshape([1, 784]).astype(np.float32)
        im_ready = np.multiply(nm_arr, 1.0 / 255.0)
        return im_ready
    
    def application():
        """Interactively classify a user-chosen number of pictures.

        Asks once how many pictures to test. For each picture it asks for the
        path, preprocesses the image with ``pre_pic``, feeds it to
        ``restore_model`` and prints the returned prediction. ``int()`` raises
        ``ValueError`` if the entered count is not an integer.
        """
        count = int(input('input the number of test pictures:'))
        for _ in range(count):
            pic_path = input('input the path of test picture:')
            # Preprocess into the network's input format, then predict with a
            # freshly restored model.
            prediction = restore_model(pre_pic(pic_path))
            print('The prediction number is: ', prediction)
    
    
    # Start the interactive prompt only when run as a script, not when imported.
    if __name__ == '__main__':
        application()

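    如果去掉 with tf.Graph().as_default() as g: 这一行,第一张图片能正常预测,但从第二张开始就会报错。原因是每次调用 restore_model 都会在同一个默认图中再建一套变量,新变量被自动命名为 Variable_4、Variable_5……,而 checkpoint 中并没有这些名字(如 Variable_7/ExponentialMovingAverage),于是恢复失败。每次调用时新建一个图即可避免这个问题。报错信息如下: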

    input the number of test pictures:10
    input the path of test picture:pic/0.png
    2018-07-19 10:19:56.079652: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
    2018-07-19 10:19:56.317172: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:892] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
    2018-07-19 10:19:56.317466: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Found device 0 with properties: 
    name: GeForce 940MX major: 5 minor: 0 memoryClockRate(GHz): 1.2415
    pciBusID: 0000:01:00.0
    totalMemory: 1.96GiB freeMemory: 1.94GiB
    2018-07-19 10:19:56.317483: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
    The prediction number is:  [0]
    input the path of test picture:pic/1.png   
    2018-07-19 10:20:12.682054: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
    2018-07-19 10:20:12.689655: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_7/ExponentialMovingAverage not found in checkpoint
    2018-07-19 10:20:12.690771: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_4/ExponentialMovingAverage not found in checkpoint
    2018-07-19 10:20:12.691076: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_5/ExponentialMovingAverage not found in checkpoint
    2018-07-19 10:20:12.691219: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_6/ExponentialMovingAverage not found in checkpoint
    Traceback (most recent call last):
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
        return fn(*args)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
        status, run_metadata)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
        c_api.TF_GetCode(self.status.status))
    tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
         [[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
         [[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "mnist_application.py", line 58, in <module>
        application()
      File "mnist_application.py", line 53, in application
        preValue = restore_model(testPicArr)
      File "mnist_application.py", line 21, in restore_model
        saver.restore(sess, ckpt.model_checkpoint_path)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1666, in restore
        {self.saver_def.filename_tensor_name: save_path})
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 889, in run
        run_metadata_ptr)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1120, in _run
        feed_dict_tensor, options, run_metadata)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
        options, run_metadata)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
         [[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
         [[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
    
    Caused by op 'save_1/RestoreV2_7', defined at:
      File "mnist_application.py", line 58, in <module>
        application()
      File "mnist_application.py", line 53, in application
        preValue = restore_model(testPicArr)
      File "mnist_application.py", line 14, in restore_model
        saver = tf.train.Saver(ema_restore)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1218, in __init__
        self.build()
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1227, in build
        self._build(self._filename, build_save=True, build_restore=True)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1263, in _build
        build_save=build_save, build_restore=build_restore)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 751, in _build_internal
        restore_sequentially, reshape)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 427, in _AddRestoreOps
        tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 267, in restore_op
        [spec.tensor.dtype])[0])
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1021, in restore_v2
        shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
        op_def=op_def)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
        op_def=op_def)
      File "/home/rmw/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
        self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
    
    NotFoundError (see above for traceback): Key Variable_7/ExponentialMovingAverage not found in checkpoint
         [[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]
         [[Node: save_1/RestoreV2_6/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_22_save_1/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
    
    

猜你喜欢

转载自blog.csdn.net/ranmw1129/article/details/81109287