tensoflow 识别数字

tensoflow 的确抽象,不过也是很不错的

现在贴段识别文体的代码 有注释,慢慢体会

#样本的准备 http://yann.lecun.com/exdb/mnist/
#通过最近临域法来判断识别数字 待检测的图片和样本图片进行比较 k个图片中我们找到个相似度最大的
#当前图片描绘的是哪些点,就需要解析图片中解析的点,需要通过lable标签获得
#将当前的lable转换为具体的数字
#检测概率统计
import tensorflow as tf
import numpy as ny
import random
from tensorflow.examples.tutorials.mnist import input_data
#load data 第一个文字路径 第二个表示 第一个参数为1 其他为0
#完成测试图片和训练图片的距离计算
#如何根据knn 中k个最近的5张图片 和500张图片做差,在500张图片中找到4张最接近的测试图片
mnist = input_data.read_data_sets('MNIST_data',one_hot = True)
# 属性设置 总共训练图片 55000张
trainNum = 55000
#测试图片 10000
testNum = 10000
#训练的时候使用的图片500张
trainSize = 500
#测试图片5张
testSize = 5
#以下数据的分解
#训练数据的下标 生存了 trainSize 这么多个随机数 范围是0到trainNum 之间 replace=False 表示不可以重复
trainIndex = ny.random.choice(trainNum,trainSize,replace=False)
testIndex = ny.random.choice(testNum,testSize,replace=False)
#当前的训练数据
trainData = mnist.train.images[trainIndex]
#获取当前训练标签
trainLable = mnist.train.labels[trainIndex]
testData = mnist.train.images[testIndex]
testLable = mnist.train.labels[testIndex]
'''
如何计算两张图片的距离,可以用两张图片对应元素相减
'''
print('traindata.shape=',trainData.shape)
print('trainLable.shape=',trainLable.shape)
print('testData.shape=',testData.shape)
print('tetstLable.shape=',testLable.shape)

trainDataInput = tf.placeholder(shape=[None,784],dtype=tf.float32)
trainLableInput = tf.placeholder(shape=[None,10],dtype=tf.float32)
testDataInput = tf.placeholder(shape=[None,784],dtype=tf.float32)
tetstLableInput = tf.placeholder(shape=[None,10],dtype=tf.float32)

#两张图片的距离
#完成维度的转换
f1 = tf.expand_dims(testData,1)#维度的扩展
f2 = tf.subtract(trainDataInput,f1) #维度相减
f3 = tf.reduce_sum(tf.abs(f2),reduction_indices=2) #完成数据累加
f4 = tf.negative(f3) #取反
f5,f6 = tf.nn.top_k(f4,k=4) #选取f4中最大的4个值 对f3来说是最小的四个值
#f6 存储的是最近的4张图片的index,根据下标索引训练出标签
f7 = tf.gather(trainLableInput,f6)
#数字的获取
f8 = tf.reduce_sum(f7,reduction_indices=1)
#选取在某一个纬度上最大的值 并记录当前的x下标
f9 = tf.argmax(f8,dimension=1)
#所有的检测图片的最大值
with tf.Session() as sess:
    p1 = sess.run(f1,feed_dict={testDataInput:testData[0:5]})
    print('p1=',p1.shape)
    p2 = sess.run(f2,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5]})
    print('p2=',p2.shape)
    p3 = sess.run(f3,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5]})
    print('p3=',p3.shape)
    print('p3[0,0]=',p3[0,0])
    p4 = sess.run(f4,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5]})
    print('p4=',p4.shape)
    print('p4[0,0]=',p4[0,0])
   #每一张测试图片分别对应4张对应的训练图片
    p5,p6 = sess.run((f5,f6),feed_dict={trainDataInput:trainData,testDataInput:testData[0:5]})
    print('p5=',p5.shape)
    print('p5[0,0]=',p5[0,0])
    print('p6=',p6.shape)
    print('p6[0,0]=',p6[0,0])
    p7 = sess.run(f7,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5],trainLableInput:trainLable})
    print('p7=',p7.shape)
    print('p7[]=',p7)
    p8 = sess.run(f8,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5],trainLableInput:trainLable})
    print('p8=',p8.shape)
    print('p8[]=',p8)
    p9 = sess.run(f9,feed_dict={trainDataInput:trainData,testDataInput:testData[0:5],trainLableInput:trainLable})
    print('p9=',p9.shape)
    print('p9[]=',p9)
    #找到测试标签中的所有内容
    p10 = ny.argmax(testLable[0:5],axis=1)
    print('p10[]=',p9)
j = 0
for i in range(0,5):
    if p10[i] == p9[i]:
        j = j+1
print(j)

猜你喜欢

转载自blog.csdn.net/renfujiang/article/details/83038747