TensorFlow 学习笔记(4)-basic_example

basic_example: nearest neighbor algorithm (用 L1 距离对 MNIST 手写数字做最近邻分类)

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 20 19:26:25 2017
@author: wu
"""
# 引入模块
from __future__ import print_function
import tensorflow as tf
import numpy as np

# Download/load the MNIST dataset into MNIST_data/ with one-hot labels.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Reference samples: the nearest-neighbour search runs over these 5000
# training images and their labels.
Xtr, Ytr = mnist.train.next_batch(5000)
# Evaluation samples: take them from the held-out test split. Drawing them
# from the training split (as before) reports accuracy on training data
# instead of on unseen images.
Xte, Yte = mnist.test.next_batch(200)

# Graph inputs: a batch of 784-element reference vectors (one row per
# training image) and a single 784-element query vector.
xtr = tf.placeholder(tf.float32, [None, 784])
xte = tf.placeholder(tf.float32, [784])

# L1 (Manhattan) distance between the query and every reference row.
# The query is broadcast across rows; summing over axis 1 leaves one
# distance per reference image. `axis` replaces the deprecated
# `reduction_indices` keyword.
distance = tf.reduce_sum(tf.abs(xtr - xte), axis=1)
# Index of the reference image with the smallest distance.
# `tf.argmin` replaces the deprecated `tf.arg_min` alias.
pred = tf.argmin(distance, 0)

# Running accuracy, accumulated by the evaluation loop.
accuracy = 0.
init = tf.global_variables_initializer()

# Launch the graph and classify every test image by its nearest neighbour.
with tf.Session() as sess:
    sess.run(init)
    # Walk the test images together with their one-hot labels.
    for i, (query, onehot) in enumerate(zip(Xte, Yte)):
        # Index (into the training batch) of the closest training image.
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: query})
        predicted = np.argmax(Ytr[nn_index])
        actual = np.argmax(onehot)
        print("Test", i, "Prediction: ", predicted, "True class: ", actual)
        # A hit is when the neighbour's label matches the true label; each
        # hit contributes an equal share so the total ends up in [0, 1].
        if predicted == actual:
            accuracy += 1. / len(Xte)
    print("Done!")
    print("Accuracy: ", accuracy)

部分运行结果截图

这里写图片描述

猜你喜欢

转载自blog.csdn.net/wchzh2015/article/details/73554953