机器学习2:KNN决策树探究泰坦尼克号幸存者问题

KNN决策树探究泰坦尼克号幸存者问题

在这里插入图片描述

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report
import graphviz   #决策树可视化
data = pd.read_csv(r"titanic_data.csv")
data.drop("PassengerId",axis = 1,inplace = True)  #删除id这一列
data
Survived Pclass Sex Age
0 0 3 male 22.0
1 1 1 female 38.0
2 1 3 female 26.0
3 1 1 female 35.0
4 0 3 male 35.0
... ... ... ... ...
886 0 2 male 27.0
887 1 1 female 19.0
888 0 3 female NaN
889 1 1 male 26.0
890 0 3 male 32.0

891 rows × 4 columns

data.loc[data["Sex"] == "male","Sex"] = 1
data.loc[data["Sex"] == "female","Sex"] = 0
data
Survived Pclass Sex Age
0 0 3 1 22.0
1 1 1 0 38.0
2 1 3 0 26.0
3 1 1 0 35.0
4 0 3 1 35.0
... ... ... ... ...
886 0 2 1 27.0
887 1 1 0 19.0
888 0 3 0 NaN
889 1 1 1 26.0
890 0 3 1 32.0

891 rows × 4 columns

data.fillna(data["Age"].mean(),inplace = True)  #用均值来填充缺失值
data
Survived Pclass Sex Age
0 0 3 1 22.000000
1 1 1 0 38.000000
2 1 3 0 26.000000
3 1 1 0 35.000000
4 0 3 1 35.000000
... ... ... ... ...
886 0 2 1 27.000000
887 1 1 0 19.000000
888 0 3 0 29.699118
889 1 1 1 26.000000
890 0 3 1 32.000000

891 rows × 4 columns

Dtc = DecisionTreeClassifier(max_depth = 5,random_state =8)  #构建决策树
Dtc.fit(data.iloc[:,1:],data["Survived"])    #模型训练
pre = Dtc.predict(data.iloc[:,1:])  #模型预测
print(classification_report(pre,data["Survived"]))   #混淆矩阵
              precision    recall  f1-score   support

           0       0.88      0.84      0.86       573
           1       0.73      0.79      0.76       318

    accuracy                           0.82       891
   macro avg       0.81      0.82      0.81       891
weighted avg       0.83      0.82      0.82       891
pre == data["Survived"]   #比较模型预测值与实际值是否一致
0       True
1       True
2       True
3       True
4       True
       ...  
886     True
887     True
888    False
889    False
890     True
Name: Survived, Length: 891, dtype: bool

可视化

dot_data = export_graphviz(Dtc,feature_names = ["Pclass","Sex","Age"],class_names="Survive")
graph  = graphviz.Source(dot_data)
graph

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_45176548/article/details/112060492