Keras实现鸢尾花数据集分类问题

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Richard__Ting/article/details/87916967

Keras实现鸢尾花数据集分类问题

Python 实现分类代码

Keras 神经网络-分类问题。

# -*- coding: utf-8 -*-
import numpy as np # 用来做矩阵运算
import pandas as pd # 用来做数据分析
from keras.models import Sequential # 模型&序列串行的类
from keras.layers import Dense # 隐含层的节点与前后都有连接,密度很高
from keras.wrappers.scikit_learn import KerasClassifier # 一个包裹的API
from keras.utils import np_utils # 待会解释
from sklearn.model_selection import cross_val_score # 交叉验证,准确度与得分
from sklearn.model_selection import KFold # KFold 将数据集中n-1个作为训练集,1个作为测试集,进行n次
from sklearn.preprocessing import LabelEncoder # 预处理,用于将标签的字符串转换为数字
from keras.models import model_from_json # 训练好模型,最后存起来,下次用完就无需再次训练,用的时候读就可以

# Reproducibility: fix the NumPy random seed so repeated runs draw the same
# random numbers. NOTE(review): this seeds NumPy only; the Keras backend may
# keep its own random state -- confirm if exact repeatability matters.
seed = 13
np.random.seed(seed)

# Load the data table. Columns 0-3 hold the four numeric features and
# column 4 holds the class name as a string.
df = pd.read_csv('./Iris_data.csv')
# Feature matrix: every row, columns 0-3, converted to float
# (NumPy's ``float`` is 64-bit double precision). Labels: column 4.
X, Y = df.values[:, :4].astype(float), df.values[:, 4]

# Encode the labels: class-name strings -> integer ids -> one-hot rows.
encoder = LabelEncoder()
Y_encoded = encoder.fit(Y).transform(Y)
Y_onehot = np_utils.to_categorical(Y_encoded)

# define a network
def baseline_model():
    """Build and compile the baseline three-layer classifier.

    Architecture:
        * input  -- 4 values, one per feature column;
        * hidden -- 7 fully connected units with ``tanh`` activation
          (the width is a free design choice);
        * output -- 3 units with ``softmax``, one per class.

    The model is meant to be trained on one-hot encoded labels.

    Returns:
        A compiled ``Sequential`` model that is optimised with plain SGD
        and reports accuracy as its metric.
    """
    model = Sequential()
    # Input layer -> hidden layer: 7 units fed by the 4 input features.
    model.add(Dense(7, input_dim=4, activation='tanh'))
    # Hidden layer -> output layer: one unit per class; softmax turns the
    # activations into a probability distribution over the 3 classes.
    model.add(Dense(3, activation='softmax'))
    # Categorical cross-entropy is the loss that matches a softmax output
    # trained against one-hot targets. The previous mean-squared-error loss
    # treats the probabilities as regression targets and gives weak
    # gradients when the softmax saturates, which slows learning.
    model.compile(loss='categorical_crossentropy', optimizer='sgd',
                  metrics=['accuracy'])
    return model

# Wrap the Keras model builder in a scikit-learn compatible estimator so it
# can be passed to cross_val_score. Each fit runs 20 epochs, uses one sample
# per batch, and prints progress (verbose=1).
estimator = KerasClassifier(
    build_fn=baseline_model,
    epochs=20,
    batch_size=1,
    verbose=1,
)

# Evaluation: 10-fold cross-validation.
# The 150 samples are split into 10 folds; each round trains on 9 folds and
# tests on the remaining one. Rows are shuffled before splitting, and the
# fixed random_state keeps the split identical between runs.
kfold = KFold(
    n_splits=10,
    shuffle=True,
    random_state=seed,
)
# One score per fold, obtained by fitting a fresh copy of the estimator.
result = cross_val_score(estimator, X, Y_onehot, cv=kfold)
# Report the mean and standard deviation of the per-fold scores.
print("Accuracy of cross validation, mean %.2f, std %.2f"
      % (result.mean(), result.std()))

# Save the model: train on the full data set, then persist it in two parts.
estimator.fit(X, Y_onehot)
# The JSON string describes the architecture only (layers, sizes,
# activation functions); the weights are not part of it.
model_json = estimator.model.to_json()
with open("./model.json", "w") as json_file:
    json_file.write(model_json)

# The trained weights are stored separately in HDF5 format.
estimator.model.save_weights("model.h5")
print("saved model to disk")

# Load the model back and use it for prediction.
# A ``with`` block guarantees the file is closed even if reading fails.
with open("./model.json", "r") as json_file:
    loaded_model_json = json_file.read()

loaded_model = model_from_json(loaded_model_json)  # rebuild the architecture
loaded_model.load_weights("model.h5")  # restore the trained weights
print("loaded model from disk")

# Per-class probabilities: one row per sample, one column per class.
predicted = loaded_model.predict(X)
print("predicted probability: " + str(predicted))

# Class index with the highest probability for each sample. This is what
# ``Sequential.predict_classes`` computed for a softmax output; that method
# no longer exists in current Keras releases, so take the argmax directly
# from the probabilities already computed above.
predicted_label = predicted.argmax(axis=-1)
print("predicted label: " + str(predicted_label))

鸢尾花数据集(csv格式)

花萼长度,花萼宽度,花瓣长度,花瓣宽度,类别
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

部分结果

...
  1/135 [..............................] - ETA: 0s - loss: 0.0269 - acc: 1.0000
 45/135 [=========>....................] - ETA: 0s - loss: 0.0873 - acc: 0.9111
 88/135 [==================>...........] - ETA: 0s - loss: 0.1006 - acc: 0.9318
102/135 [=====================>........] - ETA: 0s - loss: 0.1026 - acc: 0.9118
111/135 [=======================>......] - ETA: 0s - loss: 0.1047 - acc: 0.8919
135/135 [==============================] - 0s 2ms/step - loss: 0.1040 - acc: 0.8815
Epoch 10/20
...
Accuracy of cross validation, mean 0.88, std 0.10
...
saved model to disk
loaded model from disk
predicted probability: [[ 0.81767237  0.15703352  0.02529415]
 [ 0.78484738  0.18383147  0.03132115]
 [ 0.80815864  0.16466212  0.02717929]
 [ 0.79991156  0.17107031  0.02901812]
 [ 0.82356691  0.15209585  0.02433728]
 [ 0.81987345  0.15500849  0.02511811]
 [ 0.81340641  0.15996811  0.0266255 ]
 ...
 predicted label: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 1 1 2 1 1 1 1 1 1 2 1 1 1 2 1 2 1
 1 1 1 2 2 1 1 1 1 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

猜你喜欢

转载自blog.csdn.net/Richard__Ting/article/details/87916967