matplotlb 混淆矩阵 论文 用 python 画matlab产生的混淆矩阵

版权声明:原创禁止转载 https://blog.csdn.net/u013249853/article/details/88179684

论文当中经常会见到:

这样的混淆矩阵,之前我们讨论过怎么将其放到latex中,这里讨论下,怎么设置四舍五入,怎么更改坐标,怎么限制显示阈值。比如上面这张图,小于15的就不显示。我们的数据来源是一个matlab的.mat文件,读入后会自动生成字典,并且该混淆矩阵是一个np.arry。

直接看代码注释:

from numpy import *
import matplotlib.pyplot as plt
from pylab import *
import scipy.io
import math
#这里读入了matlab生成的.mat数据
data = scipy.io.loadmat('D:/GMM/nyuv1/con_82.mat')

#conf_arr = [[33,2,0,0,0,0,0,0,0,1,3], [3,31,0,0,0,0,0,0,0,0,0], [0,4,41,0,0,0,0,0,0,0,1], [0,1,0,30,0,6,0,0,0,0,1], [0,0,0,0,38,10,0,0,0,0,0], [0,0,0,3,1,39,0,0,0,0,4], [0,2,2,0,4,1,31,0,0,0,2], [0,1,0,0,0,0,0,36,0,2,0], [0,0,0,0,0,0,1,5,37,5,1], [3,0,0,0,0,0,0,0,0,39,0], [0,0,0,0,0,0,0,0,0,0,38] ]
#以字典的形式读入其中存的一个变量con_re_mm,该变量就是个数组,混淆矩阵,13*13的
data = data['con_re_mm']
conf_arr = data;
#这里四舍五入的取整,直接改变了数据形式,不然画出来的图全是99.0 88.0而不是99 88
conf_arr = conf_arr.astype(np.int)

norm_conf = []

for i in conf_arr:

    a = 0
    tmp_arr = []
    a = sum(i,0)
    for j in i:
        tmp_arr.append(float(j)/float(a))
    norm_conf.append(tmp_arr)

plt.clf()
fig = plt.figure()
ax = fig.add_subplot(111)
#这里打算将y坐标轴改成了文字,13个类别
sca_y = range(13)
y_name = ['Bed','Blind', 'Bookshelf', 'Cabinet', 'Ceiling', 'Floor', 'Picture', 'Sofa', 'Table','Tv','Wall', 'Window','Background']
res = ax.imshow(array(norm_conf), cmap=cm.jet, interpolation='nearest')
for i, cas in enumerate(conf_arr):
    for j, c in enumerate(cas):
        if c>15:
            plt.text(j-.2, i+.2, c, fontsize=14)
cb = fig.colorbar(res)
#这里将改变好的坐标轴加上
plt.yticks(sca_y,y_name)

之后保存成pdf就可以无损的加入到你的latex代码中,或者保存成其他格式其他用途

猜你喜欢

转载自blog.csdn.net/u013249853/article/details/88179684