生成2×2 或3*3 混淆矩阵(confusion matrix)的python代码

该代码可以生成2×2的混淆矩阵。每个矩阵对应的数值可以自行改变。
代码如下:

import numpy as np
import matplotlib.pyplot as plt

# ======================= 随机生成值 ========================
import numpy as np
import matplotlib.pyplot as plt

# 创建一个2x2的二分类数据矩阵。这里可以手动改变值
data = np.array([[143, 7], [5, 45]])

# 定义自定义的横轴和纵轴标签
x_labels = ['True', 'False']
y_labels = ['True', 'False']

# 绘制热力图
plt.imshow(data, cmap='plasma', interpolation='nearest')

# 显示数值
for i in range(2):
    for j in range(2):
        plt.text(j, i, str(data[i, j]), ha="center", va="center", color="white")

# 添加颜色条
plt.colorbar(ticks=[0, 1], label='Class')

# 设置横轴和纵轴标签
plt.xticks(range(2), x_labels)
plt.yticks(range(2), y_labels)
# 保存为.jpg文件
plt.savefig('/home/stu/画图22/heatmap1.jpg',dpi=800)
# 显示图形
plt.show()

输出效果:
在这里插入图片描述

该代码可以生成3×3的混淆矩阵。每个矩阵对应的数值可以自行改变。
代码如下:

import numpy as np
import matplotlib.pyplot as plt

# 修改混淆矩阵数据(3 × 3)值可以自行改变
confusion_matrix = np.array([[50, 10, 5],
                             [3, 45, 12],
                             [8, 6, 55]])

# 获取混淆矩阵的行数和列数
num_classes = confusion_matrix.shape[0]

# 绘制混淆矩阵
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Greens)
plt.title("Confusion Matrix on ISIC 2018")
plt.colorbar()

# 添加坐标轴标签
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, range(num_classes))
plt.yticks(tick_marks, range(num_classes))

# 在矩阵方块中添加文本标签
for i in range(num_classes):
    for j in range(num_classes):
        plt.text(j, i, confusion_matrix[i, j],
                 horizontalalignment="center",
                 color="white" if confusion_matrix[i, j] > (confusion_matrix.max() / 2) else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.savefig('/home/stu/zy/O-Net-main/画图33/heatmap1.jpg',dpi=800)
plt.show()

输出效果:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_44025103/article/details/132149785
今日推荐