# sklearn 神经网络手写数字识别

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report,confusion_matrix
import matplotlib.pyplot as plt

# Train a multi-layer perceptron on scikit-learn's bundled handwritten-digit
# dataset and report how well it classifies a held-out test split.

digits = load_digits()
x_data = digits.data    # feature matrix, one sample per row
y_data = digits.target  # class label for each sample

# Split into training and test sets. A fixed random_state makes the split
# (and therefore the reported scores) reproducible from run to run; the
# default test fraction of train_test_split is kept.
x_train, x_test, y_train, y_test = train_test_split(
    x_data, y_data, random_state=0
)

# Network shape: 64 inputs -> 100 -> 50 hidden units -> 10 outputs,
# trained for at most 500 iterations. random_state pins the weight
# initialisation so repeated runs give the same model.
# NOTE(review): MLPs are sensitive to feature scaling; standardizing the
# inputs (e.g. sklearn.preprocessing.StandardScaler fitted on x_train only)
# may speed up convergence — consider adding it.
mlp = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=0)
mlp.fit(x_train, y_train)

# Evaluate on the held-out test set: per-class precision/recall/F1 and the
# confusion matrix (rows = true labels, columns = predicted labels).
predictions = mlp.predict(x_test)
print(classification_report(y_test, predictions))
print(confusion_matrix(y_test, predictions))

猜你喜欢

转载自：https://blog.csdn.net/weixin_44823313/article/details/112434024