关于准确率的详细介绍

主要是参考了《tensorflow》pdf的文档,首先使用的是一个tf.argmax的函数,它会找出某个tensor对象在某一个维度上最大值的索引。因为一般标签都是0,1,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任意输入x预测的标签值,而tf.argmax(y_,1)代表正确的标签,我们可以用tf.equal来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

correct_prediction=tf.equal(tf.argmax(y,1),tf.agrmax(y_,1))#我大概明白了,它把所有的1的位置找出来,和预测的结果做对比

这行代码会给我们一组布尔值。为例确定正确预测的比例,我们可以把布尔值转化成浮点数,然后取平均值。例如,[True,False,True,True]会变成[1,0,1,1],取平均值后得到0.75.

accuary=tf.reduce_mean(tf.cast(correct_prediction,'float'))
#先转换成float的形式,在对其结果取平均值

最后我们计算所学习到的模型在测试数据集上面的正确率。

pring(sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels}))

猜你喜欢

转载自blog.csdn.net/hanrui4721960/article/details/80726392