Python格式化输出混淆矩阵

混淆矩阵

混淆矩阵可以用来观察模型预测结果,分析模型对各种类别预测能力。它由实际情况和预测情况两个维度构成。

举个例子,你有10张图片,其中有4只狗图片,6只猫图片,这是样本数据的实际情况。此时,你用模型来预测这些样本数据,预测结果为,4只狗图片中3只被预测为狗,1只被预测为猫,而6只猫图片中3只被预测为猫,3只被预测为狗。可以使用下表所示的混淆矩阵来表示模型的预测情况:

实际 \ 预测 狗-预测 猫-预测
狗-实际 3 1
猫-实际 3 3

Python格式化输出混淆矩阵

想要实现的效果为:输入一个混淆矩阵和类别名称,能够打印输出整洁美观的内容。

在Python中,混淆矩阵可以用numpy矩阵来表示。在矩阵中,每一行都代表着实际情形,而每一列则是模型对数据的预测结果。类别名称作为表头使用,可以很直观地看出分类结果地混淆情况,究竟是哪些类别之间错分较多。类别名称可以用列表来表示。

# -*-coding: utf-8 -*-
import numpy as np


"""
	Python格式化输出混淆矩阵
	:param confusion_matrix: 混淆矩阵,一个numpy矩阵,元素均为整型
	:param type_name: 类别名称,一个字符串列表,默认为None
	:param placeholder_length: 占位符宽度,即每个数字占几位,用于对齐,默认为5
"""

def format_print_confusion_matrix(confusion_matrix, type_name=None, placeholder_length=5):
	if type_name != None:
		type_name.insert(0, 'T \ P')    # 头部插入一个元素补齐
		for tn in type_name:
			fm = '%'+str(placeholder_length)+'s'
			print(fm%tn,end='')    # 不换行输出每一列表头
		print('\n')

	for i,cm in enumerate(confusion_matrix):
		if type_name != None:
			fm = '%'+str(placeholder_length)+'s'
			print(fm%type_name[i+1],end='')    # 不换行输出每一行表头
		
		for c in cm:
			fm = '%'+str(placeholder_length)+'d'
			print(fm%c,end='')    # 不换行输出每一行元素
		print('\n')


if __name__ == '__main__':
	confusion_matrix_example = np.array([[3,1],
										 [3,3]])
	type_name_example = ['狗', '猫']
	format_print_confusion_matrix(confusion_matrix_example, type_name_example,7)


"""
 T \ P      狗      猫

      狗      3      1

      猫      3      3
"""
发布了71 篇原创文章 · 获赞 56 · 访问量 9万+

猜你喜欢

转载自blog.csdn.net/baidu_26646129/article/details/89331747