Classification Algorithms 3: The Multiclass Confusion Matrix

Load the handwritten digits dataset

import numpy
from sklearn import datasets
import matplotlib.pyplot as plt 

digits = datasets.load_digits()
x = digits.data
y = digits.target

from sklearn.model_selection import train_test_split

# test_size=0.8 holds out 80% of the samples for testing, so only 20% are used for training
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.8,random_state=666)

Train a logistic regression classifier

from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()

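# Older scikit-learn releases (before 0.22) handled multiclass problems with OvR (one-vs-rest) by default;
# newer releases default to the multinomial (softmax) formulation with the lbfgs solver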
log_reg.fit(x_train,y_train)
y_predict = log_reg.predict(x_test)
log_reg.score(x_test,y_test)
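Because the default multiclass strategy depends on the scikit-learn version, it can help to make the choice explicit. The sketch below, assuming a recent scikit-learn, trains an explicit one-vs-rest model (ten binary classifiers wrapped in `OneVsRestClassifier`) next to a plain `LogisticRegression`, which in recent versions fits a single multinomial model. `max_iter=5000` is an assumption added only to avoid convergence warnings on the unscaled pixel values.

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.8, random_state=666
)

# One-vs-rest: one independent binary classifier per digit (10 in total)
ovr = OneVsRestClassifier(LogisticRegression(max_iter=5000))
ovr.fit(x_train, y_train)

# Multinomial (softmax): a single model over all ten classes;
# this is what LogisticRegression does by default with lbfgs in recent versions
softmax = LogisticRegression(max_iter=5000)
softmax.fit(x_train, y_train)

ovr_acc = ovr.score(x_test, y_test)
softmax_acc = softmax.score(x_test, y_test)
print(f"OvR accuracy:         {ovr_acc:.4f}")
print(f"Multinomial accuracy: {softmax_acc:.4f}")
```

The two accuracies are usually close on this dataset; the confusion matrix below is a better tool than a single score for seeing where each model goes wrong.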

View the multiclass confusion matrix

from sklearn.metrics import confusion_matrix

cfm = confusion_matrix(y_test,y_predict)
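In `confusion_matrix`, row i corresponds to the true class i and column j to the predicted class j, so cell (i, j) counts samples of class i that were predicted as j. A minimal from-scratch sketch (the helper name `confusion_matrix_manual` and the toy labels are hypothetical, for illustration only) makes that convention concrete and checks it against scikit-learn:

```python
import numpy as np
from sklearn.metrics import confusion_matrix


def confusion_matrix_manual(y_true, y_pred, n_classes):
    """Hypothetical helper: rows index the true class, columns the predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1  # one more sample of true class t predicted as p
    return cm


# Small made-up labels, just for illustration
y_true = np.array([0, 1, 2, 2, 1, 0, 2])
y_pred = np.array([0, 2, 2, 1, 1, 0, 2])

manual = confusion_matrix_manual(y_true, y_pred, n_classes=3)
print(manual)
# Should agree with scikit-learn's implementation
print(np.array_equal(manual, confusion_matrix(y_true, y_pred)))
```

The diagonal holds the correct predictions; everything off the diagonal is an error, which is what the following steps isolate.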

Display the matrix as a grayscale image (brighter cells correspond to larger counts):

# cmap is the colormap; gray maps each value to a pixel gray level
plt.matshow(cfm,cmap=plt.cm.gray)

Remove the correct predictions on the diagonal and look at the remaining (error) entries of the confusion matrix

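# keepdims=True gives row_sum shape (10, 1), so each row is divided by its own total
# (without it, NumPy broadcasting would divide column j by row_sum[j] instead)
row_sum = numpy.sum(cfm,axis=1,keepdims=True)
err_matrix = cfm / row_sum
# zero the diagonal (correct predictions) so that only the error rates remain
numpy.fill_diagonal(err_matrix,0)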

plt.matshow(err_matrix,cmap=plt.cm.gray)
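Row normalization is easy to get subtly wrong because of NumPy broadcasting. The sketch below uses a made-up 3-class matrix (the numbers are hypothetical) to show that dividing by a 1-D `row_sum` scales columns rather than rows, while `keepdims=True` produces rows that each sum to 1:

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows = true class, columns = predicted class)
cfm = np.array([[50, 2, 8],
                [1, 10, 9],
                [0, 3, 97]])

row_sum = cfm.sum(axis=1)  # shape (3,): row totals 60, 20, 100

# Wrong: (3, 3) / (3,) broadcasts along the last axis, so column j is divided by row_sum[j]
wrong = cfm / row_sum

# Right: keepdims gives shape (3, 1), so row i is divided by row_sum[i]
right = cfm / cfm.sum(axis=1, keepdims=True)

print(wrong.sum(axis=1))  # rows generally do not sum to 1
print(right.sum(axis=1))  # every row sums to 1 (up to floating-point rounding)

err = right.copy()
np.fill_diagonal(err, 0)
print(err)  # entry [i, j]: fraction of true class i that was predicted as j
```

With the correct version, each cell of the error matrix reads as "the share of class i mistaken for class j", which is what makes the brightness in the plot comparable across rows.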

The figure shows not only where the errors concentrate but also what kind of errors are made. For example, the classifier tends to misclassify 1s as 9, and 8s as 1.
On the algorithm side, one could consider adjusting the decision thresholds for classes 1, 8, and 9 to improve accuracy. On the data side, preprocessing the handwriting dataset may also help, for example removing noise and interference introduced during data collection, so that the digits are clearer and easier to recognize.
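Reading the largest errors off a grayscale plot is approximate; the same information can be listed numerically. The sketch below repeats the pipeline above (with `max_iter=5000` added as an assumption to avoid convergence warnings) and prints the five largest off-diagonal error rates. The exact pairs may differ from the pattern described above depending on the scikit-learn version and the split.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.8, random_state=666
)
log_reg = LogisticRegression(max_iter=5000).fit(x_train, y_train)
cfm = confusion_matrix(y_test, log_reg.predict(x_test))

# Row-normalize, then zero the diagonal so only the error rates remain
err_matrix = cfm / cfm.sum(axis=1, keepdims=True)
np.fill_diagonal(err_matrix, 0)

# Sort all cells of the flattened matrix by error rate, largest first
flat_order = np.argsort(err_matrix, axis=None)[::-1]
top_k = 5
top_pairs = [np.unravel_index(i, err_matrix.shape) for i in flat_order[:top_k]]
for true_label, pred_label in top_pairs:
    print(f"true {true_label} -> predicted {pred_label}: "
          f"{err_matrix[true_label, pred_label]:.3f}")
```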
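For a multiclass model, one way to "adjust a decision threshold" is to add a per-class offset to the raw scores before taking the argmax. This is a sketch of that idea under stated assumptions, not the author's method: the offset array and the value `-1.0` for class 9 are hypothetical, chosen only to show the mechanism. Lowering class 9's score can only turn some predicted 9s into other classes; whether that improves overall accuracy has to be checked empirically.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.8, random_state=666
)
log_reg = LogisticRegression(max_iter=5000).fit(x_train, y_train)

# One raw score per class, shape (n_samples, 10)
scores = log_reg.decision_function(x_test)

# Hypothetical per-class offsets: a negative offset for class 9 makes the
# model less eager to predict 9 (value chosen arbitrarily for illustration)
offsets = np.zeros(10)
offsets[9] = -1.0

baseline_pred = log_reg.classes_[np.argmax(scores, axis=1)]
adjusted_pred = log_reg.classes_[np.argmax(scores + offsets, axis=1)]

print("per-class recall, baseline:",
      np.round(recall_score(y_test, baseline_pred, average=None), 3))
print("per-class recall, adjusted:",
      np.round(recall_score(y_test, adjusted_pred, average=None), 3))
```

In practice the offsets would be tuned on a separate validation set rather than on the test set, otherwise the reported accuracy becomes optimistic.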


Origin www.cnblogs.com/shuai-long/p/11649896.html