Cross-validating a Keras model

Cross-validation

The basic idea of cross-validation is to partition the original dataset into groups: one part serves as the training set and the other as the validation (or test) set. The classifier is first trained on the training set, and the trained model is then evaluated on the validation set; the resulting score is used as a performance measure for the classifier. In K-fold cross-validation this is repeated K times, so that every group serves as the validation set exactly once, and the K scores are averaged.
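The K-fold splitting described above can be sketched with plain NumPy. This is an illustrative helper (`kfold_indices` is a hypothetical name, not a library function): it shuffles the sample indices once, cuts them into K nearly equal folds, and yields each fold as the validation set with the remaining folds as the training set. It does not reproduce the exact shuffle order of scikit-learn's `KFold`.

```python
import numpy as np

def kfold_indices(n_samples, n_splits, seed=None):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)        # shuffle the sample indices once
    folds = np.array_split(indices, n_splits)   # K nearly equal folds
    for k in range(n_splits):
        val_idx = folds[k]                      # fold k is held out for validation
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train_idx, val_idx

# 150 samples (the size of the Iris dataset) split into 10 folds
splits = list(kfold_indices(150, 10, seed=7))
for train_idx, val_idx in splits[:2]:
    print(len(train_idx), len(val_idx))  # 135 15
```

Every sample lands in exactly one validation fold, which is why averaging the K scores uses all of the data for evaluation.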

Cross-validation with a Keras model

This example uses the Iris classification dataset.
Import the modules

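from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
# KerasClassifier ships with Keras 2.x; keras.wrappers.scikit_learn is not available
# in newer Keras releases, where the separate scikeras package offers a replacement
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import to_categorical
import keras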

Import Data

iris = load_iris()
seed = 7
X = iris.data
Y = iris.target
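Before encoding the labels it helps to look at the arrays. The sketch below only uses `load_iris` and NumPy. It shows 150 samples with 4 features and 50 samples per class, and also that `iris.target` is sorted by class, which is why shuffling before splitting into folds matters for this dataset.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X, Y = iris.data, iris.target

print(X.shape)            # (150, 4): 150 samples, 4 features
print(iris.feature_names) # sepal/petal length and width in cm
print(iris.target_names)  # ['setosa' 'versicolor' 'virginica']
print(np.bincount(Y))     # [50 50 50]: samples per class
print(Y[:5])              # [0 0 0 0 0]: labels are integers, ordered by class
```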

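Inspecting the data shows that the labels in Y are integer class indices (0, 1, 2). Because the network will be trained with categorical cross-entropy, which expects one-hot targets, we one-hot encode the labels: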

Y_encode = to_categorical(Y)
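One-hot encoding turns each integer label into a row with a single 1 in that label's column. The NumPy sketch below (`one_hot` is a hypothetical helper) mirrors the default behavior of `to_categorical`: the number of classes defaults to `max(label) + 1` and the result is `float32`.

```python
import numpy as np

def one_hot(labels, num_classes=None):
    """Illustrative one-hot encoder: row i has a 1 in column labels[i]."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = labels.max() + 1            # infer the class count from the labels
    return np.eye(num_classes, dtype="float32")[labels]  # pick rows of the identity matrix

encoded = one_hot([0, 1, 2, 1])
print(encoded)
# rows: [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]
print(encoded.argmax(axis=1))  # argmax recovers the original labels: [0 1 2 1]
```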

Build a deep learning network model

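def build_model():
    # Network: 4 input features -> 7 tanh units -> 3 softmax outputs (one per class)
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(7, activation='tanh', input_shape=(4,)))
    model.add(keras.layers.Dense(3, activation='softmax'))
    # categorical_crossentropy expects one-hot targets, hence to_categorical above
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model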
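To see what `build_model` computes, here is a NumPy sketch of the same forward pass: a tanh hidden layer of 7 units, a 3-way softmax output, and the categorical cross-entropy loss. The weights are random placeholders with the same shapes as the Keras layers; they are not Keras's initializers or trained values.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder weights shaped like Dense(7) on 4 inputs and Dense(3) on 7 inputs
W1, b1 = rng.normal(scale=0.5, size=(4, 7)), np.zeros(7)
W2, b2 = rng.normal(scale=0.5, size=(7, 3)), np.zeros(3)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x):
    hidden = np.tanh(x @ W1 + b1)      # hidden layer, shape (batch, 7)
    return softmax(hidden @ W2 + b2)   # class probabilities, shape (batch, 3)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

probs = forward(rng.normal(size=(2, 4)))
print(probs.sum(axis=1))  # each row of probabilities sums to 1

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)  # 59 = (4*7 + 7) + (7*3 + 3)
```

The parameter count matches what `model.summary()` reports for this architecture: 35 weights in the hidden layer and 24 in the output layer.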

Run the cross-validation

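# The wrapper calls build_model() to create a fresh network each time it is fit
model = KerasClassifier(build_fn=build_model, epochs=20, batch_size=1)
# shuffle=True matters here because iris.target is sorted by class
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
result = cross_val_score(model, X, Y_encode, cv=kfold)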
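It can help to spell out what `cross_val_score` does with `kfold`. The sketch below writes the equivalent manual loop: for each split it clones the estimator, fits the clone on the training indices, and scores it on the validation indices. The `KerasClassifier` wrapper calls `build_fn` inside `fit`, so each fold likewise starts from an untrained network. As an assumption made only so the sketch runs without Keras, `LogisticRegression` stands in for the Keras model and integer labels are used, since it does not need one-hot targets.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

X, Y = load_iris(return_X_y=True)
kfold = KFold(n_splits=10, shuffle=True, random_state=7)

# Stand-in estimator (assumption): replaces the Keras model for this sketch
estimator = LogisticRegression(max_iter=1000)

manual_scores = []
for train_idx, val_idx in kfold.split(X):
    fold_model = clone(estimator)                 # fresh, unfitted copy for every fold
    fold_model.fit(X[train_idx], Y[train_idx])    # train on 9 folds
    manual_scores.append(fold_model.score(X[val_idx], Y[val_idx]))  # accuracy on the held-out fold
manual_scores = np.array(manual_scores)

library_scores = cross_val_score(estimator, X, Y, cv=kfold)
print(np.allclose(manual_scores, library_scores))  # True
```

Because `random_state` is an integer, `kfold.split` produces the same folds every time it is called, so the manual loop and `cross_val_score` evaluate identical splits.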

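n_splits=10 specifies 10-fold cross-validation: the data is split into 10 folds, and each fold is used once as the validation set while the model is trained on the remaining 9.
View the cross-validation results: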

print("============")
print("mean:",result.mean())
print("std:",result.std())

The mean cross-validation accuracy is 0.9000000059604645 (about 90%), and the standard deviation of the accuracy across the 10 folds is 0.0906764646534975.
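The two reported numbers are the mean and standard deviation of the 10 per-fold accuracies. With 150 samples and 10 folds, each validation fold holds 15 samples, so a fold's accuracy is always a multiple of 1/15 (about 0.067), and one misclassification more or less noticeably moves the standard deviation. The long trailing digits are most likely float32 rounding in Keras's accuracy metric. The sketch below uses hypothetical per-fold counts (not the author's results) to show the arithmetic; note that NumPy's `ndarray.std()` defaults to the population standard deviation (`ddof=0`).

```python
import numpy as np

# Hypothetical correct-prediction counts per fold (NOT the author's results)
correct_per_fold = np.array([15, 14, 12, 13, 15, 11, 14, 13, 15, 13])
fold_size = 150 // 10                      # 15 validation samples per fold
fold_acc = correct_per_fold / fold_size    # each accuracy is a multiple of 1/15

print(fold_acc.mean())       # mean accuracy over the 10 folds (135 / 150 correct here)
print(fold_acc.std())        # population std (ddof=0), as reported by result.std()
print(fold_acc.std(ddof=1))  # sample std, slightly larger
```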


Original article: blog.csdn.net/weixin_42494845/article/details/108609800