Machine Learning: Neural Network (Python) for Handwritten Digit Recognition

1. Definition

Neural networks have long been described as being modeled on the biological neural networks of living organisms.

Neural networks can be used for both regression and classification, but they are often used for classification in practical applications. Deep learning based on neural networks is widely recognized for its excellent performance in areas such as image recognition and speech recognition.

A neuron is the basic unit of a neural network. It is essentially a function that receives external stimuli (inputs) and produces a corresponding output. Internally, it can be seen as a linear function followed by an activation function: the linear function computes a weighted sum of the inputs plus a bias, and its result is passed to the activation function, which produces the neuron's final output.
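To make the "linear function plus activation function" description concrete, here is a minimal pure-Python sketch of a single neuron. The weights, the bias, and the choice of sigmoid and ReLU as activations are illustrative assumptions, and the helper names sigmoid, relu, and neuron are hypothetical, not part of any library. (For reference, scikit-learn's MLPClassifier uses ReLU in its hidden layers by default.)

```python
import math


def sigmoid(z: float) -> float:
    """Logistic activation: squashes any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def relu(z: float) -> float:
    """ReLU activation: max(0, z)."""
    return max(0.0, z)


def neuron(inputs, weights, bias, activation=sigmoid):
    """One artificial neuron (hypothetical helper, for illustration only)."""
    if len(inputs) != len(weights):
        raise ValueError("inputs and weights must have the same length")
    # Linear function: weighted sum of the inputs plus the bias
    z = sum(x * w for x, w in zip(inputs, weights)) + bias
    # Activation function: turns the linear result into the neuron's output
    return activation(z)


if __name__ == "__main__":
    x = [1.0, 2.0]
    w = [0.5, -0.25]  # illustrative weights, not learned values
    print(neuron(x, w, bias=0.0))                   # weighted sum is 0, sigmoid(0) = 0.5
    print(neuron(x, w, bias=1.0, activation=relu))  # weighted sum is 1, relu(1) = 1.0
```

With input [1.0, 2.0] and weights [0.5, -0.25] the weighted sum is 0, so the sigmoid neuron outputs 0.5; adding a bias of 1.0 and switching to ReLU gives 1.0. A multilayer perceptron such as MLPClassifier stacks many of these neurons into layers and learns the weights and biases from data.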

2. Code

from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

digits = load_digits()
x = digits.data    # input features: each 8x8 image flattened into 64 pixel values
y = digits.target  # labels: the digits 0-9
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.25,random_state=33)

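# The raw data is usually split by a given ratio into a "training set" and a "test set"
# test_size: fraction of samples used for the test set; if it is an integer, the absolute number of test samples
# random_state: seed for the random number generator, so the split is reproducible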

# Standardize the data

from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
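x_test_copy = x_test.copy()   # keep the unscaled test images for display later
# fit_transform() first fits the scaler (learns each feature's mean and standard deviation), then standardizes
x_train = ss.fit_transform(x_train)
print(len(x_train))           # number of training samples
# transform() standardizes the test set using the statistics learned from the training set
x_test = ss.transform(x_test)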

from sklearn.neural_network import MLPClassifier
mclf = MLPClassifier()
mclf.fit(x_train, y_train)          # fit the classifier on the training data
y_predict = mclf.predict(x_test)    # predict labels for the test data

print('Accuracy:', mclf.score(x_test, y_test))  # mean accuracy on the test set

# Detailed evaluation: per-class precision, recall, F1-score and support

from sklearn.metrics import classification_report
print(classification_report(y_test, y_predict, target_names=digits.target_names.astype(str)))
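n = int(input('Enter n to view the n-th test sample: '))
plt.figure(figsize=(8, 8))
plt.imshow(x_test_copy[n].reshape(8, 8), cmap='gray_r')  # show the unscaled 8x8 image
plt.show()
# print(x_test_copy[n].reshape(8, 8))  # optionally print the raw pixel matrix
print('Predicted:', y_predict[n], ' Actual:', y_test[n])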
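The script relies on two behaviors it only hints at: StandardScaler learns each feature's mean and standard deviation from the training set and reuses them on the test set, and score() returns the fraction of correct predictions. The NumPy sketch below reproduces both under those assumptions; fit_standardizer, standardize, and accuracy are hypothetical helper names, not scikit-learn APIs, and the toy arrays are made up for illustration.

```python
import numpy as np


def fit_standardizer(x_train: np.ndarray):
    """Learn per-feature mean and std from the training set only (hypothetical helper)."""
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses
    std[std == 0] = 1.0        # constant features would otherwise divide by zero
    return mean, std


def standardize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """z = (x - mean) / std, using statistics learned from the training set."""
    return (x - mean) / std


def accuracy(y_true, y_pred) -> float:
    """Fraction of predictions that equal the true label."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


if __name__ == "__main__":
    # Toy data: the second feature is constant, like an always-blank pixel
    x_train = np.array([[0.0, 1.0], [2.0, 1.0], [4.0, 1.0]])
    x_test = np.array([[2.0, 5.0]])

    mean, std = fit_standardizer(x_train)          # "fit" on training data only
    print(standardize(x_train, mean, std))         # like ss.fit_transform(x_train)
    print(standardize(x_test, mean, std))          # like ss.transform(x_test)
    print(accuracy([1, 2, 3, 4], [1, 2, 0, 4]))    # 0.75
```

Computing the statistics on the training set only, as the script does with fit_transform and transform, keeps information about the test set out of training. The zero-std guard is relevant to image data like this, where some pixels may be constant across all training images.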


Source: blog.csdn.net/maggieyiyi/article/details/123920044