Matriz de confusión e índices de evaluación de segmentación semántica: Acc, CAcc, MAcc, IoU, MIoU, FWIoU
Verificación
"""
refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
"""
import numpy as np
__all__ = ['SegmentationMetric']
"""
Confusion-matrix layout (rows: ground truth / 真实, columns: prediction / 预测):

                      Predicted (预测)
                      P       N
Actual (真实)    P    TP      FN
                 N    FP      TN
"""
class SegmentationMetric(object):
    """Accumulate a confusion matrix over batches and derive segmentation metrics.

    ``confusionMatrix[i, j]`` counts pixels whose ground-truth class is ``i`` and
    whose predicted class is ``j``. Pixels whose label lies outside
    ``[0, numClass)`` (e.g. an ignore value such as 255 or -1) are not counted.

    Per-class metrics are NaN for a class whose denominator is zero (the class
    never occurs); the ``mean*`` methods skip those NaNs via ``np.nanmean``.
    """

    def __init__(self, numClass):
        """Create an empty accumulator for ``numClass`` classes."""
        self.numClass = numClass
        # float64 (numClass, numClass) running total; batches are added into it.
        self.confusionMatrix = np.zeros((self.numClass,) * 2)

    def genConfusionMatrix(self, imgPredict, imgLabel):
        """Return the confusion matrix of a single prediction/label pair.

        Args:
            imgPredict: array-like of predicted class ids, same shape as ``imgLabel``.
            imgLabel: array-like of ground-truth class ids; entries outside
                ``[0, numClass)`` are ignored.

        Returns:
            Integer array of shape ``(numClass, numClass)``.

        Raises:
            ValueError: if a counted pixel has a predicted id outside ``[0, numClass)``.
        """
        imgPredict = np.asarray(imgPredict)
        imgLabel = np.asarray(imgLabel)
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        # np.bincount only accepts non-negative integers; cast so that
        # float-typed inputs (e.g. decoded images) work as well.
        gt = imgLabel[mask].astype(np.int64)
        pred = imgPredict[mask].astype(np.int64)
        # An out-of-range prediction would otherwise be folded into a
        # neighbouring cell of the flattened matrix (or break the reshape).
        if pred.size and (pred.min() < 0 or pred.max() >= self.numClass):
            raise ValueError(
                'imgPredict contains class ids outside [0, %d)' % self.numClass)
        # Encode each (gt, pred) pair as one flat index: row-major cell id.
        count = np.bincount(self.numClass * gt + pred,
                            minlength=self.numClass ** 2)
        return count.reshape(self.numClass, self.numClass)

    def addBatch(self, imgPredict, imgLabel):
        """Add the confusion matrix of one prediction/label pair to the running total."""
        assert np.shape(imgPredict) == np.shape(imgLabel), (
            'imgPredict and imgLabel must have the same shape')
        self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)

    def pixelAccuracy(self):
        """Return correctly classified pixels / counted pixels (NaN if nothing counted)."""
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()

    def classPixelAccuracy(self):
        """Return per-class accuracy TP / (TP + FN); NaN for classes with no ground truth."""
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)

    def meanPixelAccuracy(self):
        """Return the mean of ``classPixelAccuracy`` over classes, ignoring NaNs."""
        return np.nanmean(self.classPixelAccuracy())

    def IntersectionOverUnion(self):
        """Return per-class IoU = TP / (TP + FP + FN); NaN when the union is empty."""
        intersection = np.diag(self.confusionMatrix)
        union = (self.confusionMatrix.sum(axis=1)
                 + self.confusionMatrix.sum(axis=0)
                 - intersection)
        with np.errstate(divide='ignore', invalid='ignore'):
            return intersection / union

    def meanIntersectionOverUnion(self):
        """Return the mean of ``IntersectionOverUnion`` over classes, ignoring NaNs."""
        return np.nanmean(self.IntersectionOverUnion())

    def FrequencyWeightedIntersectionOverUnion(self):
        """Return IoU averaged with weights = each class's share of ground-truth pixels."""
        with np.errstate(divide='ignore', invalid='ignore'):
            freq = self.confusionMatrix.sum(axis=1) / self.confusionMatrix.sum()
        iou = self.IntersectionOverUnion()
        # Only classes that actually occur contribute (their IoU is never NaN).
        present = freq > 0
        return (freq[present] * iou[present]).sum()

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusionMatrix = np.zeros((self.numClass, self.numClass))
if __name__ == '__main__':
    # Toy example: per-class scores shaped (batch=1, classes=3, H=3, W=3)
    # and a matching (1, 3, 3) ground-truth label map.
    scores = np.array([[
        [[1, 1, 0], [2, 0, 1], [2, 1, 0]],
        [[2, 0, 2], [0, 0, 1], [1, 1, 2]],
        [[1, 1, 0], [0, 0, 2], [1, 0, 2]],
    ]])
    ground_truth = np.array([[
        [1, 1, 0],
        [1, 0, 0],
        [2, 2, 2],
    ]])

    # Predicted class per pixel = index of the highest score along axis 1.
    predicted = scores.argmax(axis=1)

    evaluator = SegmentationMetric(3)
    evaluator.addBatch(predicted, ground_truth)

    # Compute every metric first, then print them in a fixed order.
    report = [
        ('argmax_output\n', predicted),
        ('混淆矩阵:\n', evaluator.confusionMatrix),
        ('acc:\t', evaluator.pixelAccuracy()),
        ('cacc:\t', evaluator.classPixelAccuracy()),
        ('mcacc:\t', evaluator.meanPixelAccuracy()),
        ('IoU:\t', evaluator.IntersectionOverUnion()),
        ('mIoU:\t', evaluator.meanIntersectionOverUnion()),
        ('FWmIoU:\t', evaluator.FrequencyWeightedIntersectionOverUnion()),
    ]
    for tag, value in report:
        print(tag, value)
import numpy as np
import torch
def compute_metrics(pred, label, numClass):
    """Return the ``numClass x numClass`` confusion matrix of one prediction/label pair.

    Row ``i`` / column ``j`` counts pixels with ground-truth class ``i`` and
    predicted class ``j``. Pixels whose label is outside ``[0, numClass)``
    (e.g. an ignore value such as 255) are skipped.

    Args:
        pred: array-like of predicted class ids, same shape as ``label``.
        label: array-like of ground-truth class ids.
        numClass: number of classes.

    Returns:
        Integer array of shape ``(numClass, numClass)``.

    Raises:
        ValueError: if a counted pixel has a predicted id outside ``[0, numClass)``.
    """
    pred = np.asarray(pred)
    label = np.asarray(label)
    mask = (label >= 0) & (label < numClass)
    # np.bincount needs integers: cast BOTH inputs (previously only label was).
    gt_ids = label[mask].astype(np.int64)
    pred_ids = pred[mask].astype(np.int64)
    # Out-of-range predictions would silently land in another cell.
    if pred_ids.size and (pred_ids.min() < 0 or pred_ids.max() >= numClass):
        raise ValueError('pred contains class ids outside [0, %d)' % numClass)
    metrics = np.bincount(
        numClass * gt_ids + pred_ids,
        minlength=numClass ** 2).reshape(numClass, numClass)
    return metrics


def compute_miou(predictions, label, numClass, *, ignore_absent=False):
    """Accumulate confusion matrices over paired samples and derive summary metrics.

    Args:
        predictions: iterable of prediction arrays (e.g. a stacked ``(N, H, W)``
            array), paired element-wise with ``label`` via ``zip``.
        label: iterable of ground-truth arrays, paired with ``predictions``.
        numClass: number of classes.
        ignore_absent: if False (default, historical behaviour) a class with an
            empty denominator scores 0 in ``cls_acc`` / ``iou`` and still counts
            toward ``mean_acc`` / ``miou``. If True it is reported as NaN and
            skipped by the means, matching ``SegmentationMetric``.

    Returns:
        ``(acc, cls_acc, mean_acc, iou, miou, fwmiou)``: overall pixel accuracy,
        per-class accuracy, mean class accuracy, per-class IoU, mean IoU and
        frequency-weighted IoU. ``acc`` is NaN when no pixel was counted.
    """
    metrics = np.zeros((numClass, numClass))
    for p, lab in zip(predictions, label):
        metrics += compute_metrics(np.ravel(p), np.ravel(lab), numClass)

    diag = np.diag(metrics)
    gt_count = metrics.sum(axis=1)    # pixels per ground-truth class
    pred_count = metrics.sum(axis=0)  # pixels per predicted class
    union = gt_count + pred_count - diag
    total = metrics.sum()
    fill = np.nan if ignore_absent else 0.0

    # 0/0 only occurs for absent classes (replaced by `fill`) or an empty
    # total; silence the RuntimeWarnings for exactly those cases.
    with np.errstate(divide='ignore', invalid='ignore'):
        acc = diag.sum() / total
        cls_acc = np.where(gt_count == 0, fill, diag / gt_count)
        iou = np.where(union == 0, fill, diag / union)
        freq = gt_count / total

    mean_acc = np.nanmean(cls_acc)
    miou = np.nanmean(iou)
    present = freq > 0
    fwmiou = (freq[present] * iou[present]).sum()
    return acc, cls_acc, mean_acc, iou, miou, fwmiou
if __name__ == '__main__':
    # Random 8-class example: two 3x3 prediction maps and two 3x3 label maps.
    np.random.seed(13)
    imgPredict = np.random.randint(0, 8, (2, 3, 3))
    imgLabel = np.random.randint(0, 8, (2, 3, 3))
    print('compute_metrics\n', compute_metrics(imgPredict, imgLabel, 8))
    acc, cls_acc, mean_acc, iou, miou, fwmiou = compute_miou(imgPredict, imgLabel, 8)
    # Adjacent f-string literals are concatenated; keeping each replacement
    # field on one line stays valid on Python < 3.12, where a newline inside
    # a single-quoted f-string is a SyntaxError.
    print(
        f'acc:{acc}\n'
        f'cls_acc:{cls_acc}\n'
        f'mean_acc:{mean_acc}\n'
        f'iou:{iou}\n'
        f'miou:{miou}\n'
        f'fwmiou:{fwmiou}'
    )