使用sklearn.MLPClassifier的简单例子

训练

MLPClassifier的hidden_layer_sizes可以设置需要的神经网络的隐藏层数及每一个隐藏层的神经元个数,比如(3,2)表示该神经网络拥有两个隐藏层,第一个隐藏层有3个神经元,第二个隐藏层有2个神经元。其他的参数具体见官方文档
下例中还使用了KFold进行了交叉检验,并存下其结果,最后将几次Fold中结果最好的分类器保存下来。

# two-layer neural network 
# train part

import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import KFold
from joblib import dump

#get training data
X = train_data[:,1:]
y = train_data[:,0]  

#neural network classifier of structure (3,2)
kf = KFold(n_splits=3) # 3-fold cross-validation
best_clf = None
best_score = 0
train_scores = []
test_scores = []
print("kfold-------")
for train_index, test_index in kf.split(X):
    # create neural network using MLPClassifer
    clf = MLPClassifier(solver = 'sgd', activation = 'logistic', max_iter = 1000, hidden_layer_sizes = (3,2),random_state = 1)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    clf.fit(X_train, y_train)
    train_score = clf.score(X_train, y_train)
    train_scores.append(train_score)
 
    test_score = clf.score(X_test, y_test)
    test_scores.append(test_score)

    #compare score of the tree models and get the best one
    if test_score > best_score:
        best_score = test_score
        best_clf = clf
    
    #print(clf.n_outputs_)
in_sample_error = [1 - score for score in train_scores]
test_set_error = [1 - score for score in test_scores]
print("in_sample_error: ")
print(in_sample_error)
print("test_set_error: ")
print(test_set_error)

#store the classifier
if best_clf != None:
    dump(best_clf, "train_model.m")

测试

直接加载之前训练好并保存下来的分类器,并测试

# test part

import numpy as np
from sklearn.neural_network import MLPClassifier
from joblib import load

X_test = test_data[:,1:]
y_test = test_data[:,0]

clf = load("train_model.m")
y_pred = clf.predict(X_test)
np.savetxt("label_pred.txt", np.array(y_pred)) #save predict result
#print(y_pred)
test_score = clf.score(X_test, y_test)
test_error = 1 - test_score
print('test_score:%s' % test_score)
print('test_error:%s' % test_error)

猜你喜欢

转载自www.cnblogs.com/liuxin0430/p/12130346.html