Handwritten digit recognition: loading the digits dataset
import numpy
from sklearn import datasets
import matplotlib.pyplot as plt
digits = datasets.load_digits()
x = digits.data
y = digits.target
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.8,random_state=666)
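Before training, it can help to check what was loaded. A minimal sketch (the variable name first_image is illustrative and not part of the original code): each sample in digits.data is an 8x8 grayscale image flattened into 64 features, and the same samples are available unflattened in digits.images.

```python
from sklearn import datasets

digits = datasets.load_digits()

# 1797 samples, each an 8x8 image flattened into 64 pixel features
print(digits.data.shape)
print(digits.images.shape)

# Reshaping a feature row recovers the original 8x8 image
first_image = digits.data[0].reshape(8, 8)
print(first_image)
print(digits.target[0])
```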
Training with logistic regression
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
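# Older sklearn versions solve multiclass problems with OvR (one-vs-rest) by default;
# newer releases default to a multinomial formulation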
log_reg.fit(x_train,y_train)
y_predict = log_reg.predict(x_test)
log_reg.score(x_test,y_test)
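To use the one-vs-rest strategy explicitly, independent of the installed scikit-learn version, one option is to wrap the model in OneVsRestClassifier. This is a sketch and not the original author's code. It assumes max_iter=5000, which is raised here only to reduce convergence warnings on the unscaled pixel values. It also shows that score returns plain accuracy.

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.8, random_state=666)

# One binary logistic regression is trained per digit (10 in total)
ovr = OneVsRestClassifier(LogisticRegression(max_iter=5000))
ovr.fit(x_train, y_train)

y_predict_ovr = ovr.predict(x_test)
ovr_score = ovr.score(x_test, y_test)

# For a classifier, score() is the fraction of correct predictions
ovr_accuracy = accuracy_score(y_test, y_predict_ovr)
print(ovr_score, ovr_accuracy)
```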
Viewing the multiclass confusion matrix
from sklearn.metrics import confusion_matrix
cfm = confusion_matrix(y_test,y_predict)
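The layout of the matrix is easier to see on a small example. In the sketch below (toy labels made up for illustration), row i holds the samples whose true label is i and column j holds the samples predicted as j, so correct predictions lie on the diagonal.

```python
from sklearn.metrics import confusion_matrix

# Toy labels, invented for illustration only
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]

toy_cfm = confusion_matrix(y_true, y_pred)
print(toy_cfm)
# [[1 1 0]   one 0 predicted correctly, one 0 predicted as 1
#  [0 2 0]   both 1s predicted correctly
#  [1 0 2]]  one 2 predicted as 0, two 2s predicted correctly
```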
Display the matrix values as grayscale intensities:
# cmap is the color map; gray renders each value as a grayscale pixel intensity
plt.matshow(cfm,cmap=plt.cm.gray)
Remove the correct predictions on the diagonal and view the remaining errors in the confusion matrix
# keepdims=True keeps the sums as a column, so each row is divided by its own total
row_sum = numpy.sum(cfm,axis=1,keepdims=True)
err_matrix = cfm / row_sum
numpy.fill_diagonal(err_matrix,0)
plt.matshow(err_matrix,cmap=plt.cm.gray)
The figure shows not only where most of the mistakes occur but also what kind of mistakes they are. For example, the algorithm tends to predict the digit 1 as 9, and the digit 8 as 1.
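The same information can be read off numerically instead of from the plot. The helper below, top_confusions, is a hypothetical function written for this illustration. It assumes every class appears at least once among the true labels, so that no row sum is zero, and it is shown on a made-up 3x3 matrix.

```python
import numpy

def top_confusions(cfm, k=3):
    """Return the k largest off-diagonal error rates as (true, predicted, rate)."""
    cfm = numpy.asarray(cfm, dtype=float)
    # Divide each row by its own total to get per-class error rates
    err = cfm / cfm.sum(axis=1, keepdims=True)
    numpy.fill_diagonal(err, 0)
    # Positions in the flattened matrix, sorted by descending error rate
    order = numpy.argsort(err, axis=None)[::-1][:k]
    rows, cols = numpy.unravel_index(order, err.shape)
    return [(int(r), int(c), float(err[r, c])) for r, c in zip(rows, cols)]

# Made-up confusion matrix: rows are true labels, columns are predictions
toy_cfm = [[8, 2, 0],
           [1, 9, 0],
           [4, 0, 6]]
pairs = top_confusions(toy_cfm, k=2)
print(pairs)
```

Calling top_confusions(cfm) on the digits confusion matrix would list the most frequently confused digit pairs in the same form.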
On the algorithm side, we should consider adjusting the decision thresholds for the digits 1, 8, and 9 to improve accuracy. On the data side, preprocessing of the handwritten digit dataset should also be considered, such as removing noise and interfering points introduced during data collection, so that the images are clearer and easier to recognize.
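One way to start exploring decision thresholds is to look at the predicted class probabilities. This is a sketch under stated assumptions and not the original author's method: the cut-off threshold = 0.6 is an arbitrary illustrative value, and max_iter=5000 is set only to reduce convergence warnings. Predictions whose highest probability falls below the cut-off are treated as uncertain and can be inspected separately.

```python
import numpy
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.8, random_state=666)

log_reg = LogisticRegression(max_iter=5000)
log_reg.fit(x_train, y_train)

# One probability column per class, in the order of log_reg.classes_
proba = log_reg.predict_proba(x_test)
confidence = proba.max(axis=1)

# Hypothetical cut-off: predictions below it are treated as uncertain
threshold = 0.6
confident = confidence >= threshold

y_predict = log_reg.predict(x_test)
acc_all = numpy.mean(y_predict == y_test)
acc_confident = numpy.mean(y_predict[confident] == y_test[confident])

# Overall accuracy, accuracy on confident predictions, share of confident predictions
print(acc_all, acc_confident, confident.mean())
```

Comparing the uncertain samples whose true label is 1, 8, or 9 against the confident ones is a reasonable first check before changing any threshold.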