先来看最基础的二分类模型的混淆矩阵
Prediction | |||
Positive | Negative | ||
Target | True | 真阳性 - TP | 假阴性 - FN |
False | 假阳性 - FP | 真阴性 - TN |
根据这个矩阵,可以计算得到分类模型的几个指标:
准确率 | 预测正确样本在所有样本中的占比 | |
查准率 (精确度) |
真阳性样本在所有预测为阳性的样本 中的占比 |
|
查全率 (召回率) |
真阳性样本在所有真实为阳性的样本 中的占比,自动驾驶的目标检测因为 在遗漏目标时有巨大的安全隐患, 所以要求接近 100% 的召回率 |
|
Fβ-score | 查准率和查全率的调和平均数, β 表示了查准率的权值 (查全率的权值为 1), 比较常用的是: F1-score、F2-score、F0.5-score |
对于多分类模型,以 dog 为正样本 (其它分类为负样本),可得到如下的混淆矩阵:
Prediction | |||||
cat | dog | car | student | ||
Target | cat | FP | |||
dog | FN | TP | FN | FN | |
car | FP | ||||
student | FP |
由此可知,对于每一个类别,都可以由混淆矩阵各计算出 Precision、Recall
代码实现
使用一个类来创建混淆矩阵,并需要多个类方法来求解 Accuracy、Precision、Recall、Fβ-score
- _div:类方法,定义了防止除零、对结果保留 4 位小数的除法,具体计算为:
- __init__:根据 pred 和 target (支持 numpy 和 torch,不支持 list,在传参后调用 flatten 函数展平成行向量) 计算混淆矩阵
- _tp:使用 property 管理的类变量,每次访问都进行一次计算,对应每个类别的 TP
- accuracy、precision、recall:使用 property 管理的类变量,每次访问都进行一次计算
- f_score:可指定 β 的值,在此避免了 precision、recall 的二次计算
- eval:输出 Accuracy、Precision、Recall、F-Score 构成的字典
- __add__:支持 Crosstab 类的累加
import numpy as np
class Crosstab:
_div = lambda self, a, b, decimal=4, eps=1e-5: np.round(a / (b + eps), decimal)
def __init__(self, pred, target, classes=2):
assert all('int' in str(y.dtype) for y in (pred, target)), 'Only integer can be used to represent categories'
pred, target = map(lambda x: x.flatten(), (pred, target))
self._data = np.bincount(pred + classes * target, minlength=classes ** 2).reshape([classes] * 2)
_tp = property(fget=lambda self: np.diag(self._data))
accuracy = property(fget=lambda self: self._div(self._tp.sum(), self._data.sum()))
precision = property(fget=lambda self: self._div(self._tp, self._data.sum(axis=0)))
recall = property(fget=lambda self: self._div(self._tp, self._data.sum(axis=1)))
def f_score(self, beta=1.):
alpha = beta ** 2
precision, recall = self.precision, self.recall
return self._div((1 + alpha) * precision * recall, alpha * precision + recall)
def eval(self, beta=1.):
return {'Accuracy': self.accuracy, 'Precision': self.precision,
'Recall': self.recall, f'F{beta:.1f}-Score': self.f_score(beta)}
def __add__(self, other):
if not isinstance(other, type(self)):
types = tuple(map(lambda x: type(x).__name__, (self, other)))
raise TypeError(f'unsupported operand type(s) for +: \'{types[0]}\' and \'{types[1]}\'')
self._data += other._data
return self
def __str__(self):
return str(self._data)
__repr__ = __str__