【tensorflow】MTCNN网络基本函数cal_accuracy()

cal_accuracy()用于精度计算。

import numpy as np
import tensorflow as tf
def cal_accuracy(cls_prob,label):
    '''
    :param cls_prob:
    :param label:
    :return:calculate classification accuracy for pos and neg examples only
    '''
    # get the index of maximum value along axis one from cls_prob
    # 0 for negative 1 for positive
    #按行返回cls_prob的最大值的索引,索引值为0或者1,索引等于0时
    #表示这个图片网络预测为非人脸;为1时网络预测这张图片为人脸
    pred = tf.argmax(cls_prob,axis=1)
    label_int = tf.cast(label,tf.int64)
    #tf.greater_equal()函数判断label_int是否大于等于0,返回True或者False
    #tf.where函数()返回True值对应的索引,即cond是pos样本和part样本对应的索引
    cond = tf.where(tf.greater_equal(label_int,0))
    picked = tf.squeeze(cond)
    # 获得pos样本和part样本的label
    label_picked = tf.gather(label_int,picked)
    #获得pos和part的预测值
    pred_picked = tf.gather(pred,picked)
    with tf.Session() as sess:
        print("pred:%s" % (sess.run(pred)))
        print("label_int:%s" % (sess.run(label_int)))
        print("cond:%s" % (sess.run(cond)))
        print("picked:%s" % (sess.run(picked)))
        print("label_picked:%s" % (sess.run(label_picked)))
        print("pred_picked:%s" % (sess.run(pred_picked)))
    #通过tf.equal()函数返回的True或者False值得到网络的预测值是否准确
    #将True和Flase转化为1和0求得平均值即得到准确率
    t = tf.cast(tf.equal(label_picked,pred_picked),tf.float32)
    accuracy_op = tf.reduce_mean(t)
    with tf.Session() as sess:
        print("t:%s" % (sess.run(t)))
        print("accuracy_op:%s" % sess.run(accuracy_op))
    return accuracy_op
cls_prob = tf.random_uniform([2,2],0,1,seed = 100)
with tf.Session() as sess:
    print("cls_prob:%s"%(sess.run(cls_prob)))
label = np.array([1, 0])
cal_accuracy(cls_prob,label)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/zhouzongzong/article/details/94735785