Machine learning 2: Predicting Titanic survivors with a decision tree

Predicting Titanic survivors with a decision tree

Insert picture description here

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report
import graphviz  # renders the exported decision tree

# Load the Titanic table and discard the PassengerId identifier column;
# every remaining column except "Survived" is later used as a feature.
data = pd.read_csv("titanic_data.csv").drop(columns="PassengerId")
data  # notebook cell: display the cleaned frame
Survived Pclass Sex Age
0 0 3 male 22.0
1 1 1 female 38.0
2 1 3 female 26.0
3 1 1 female 35.0
4 0 3 male 35.0
... ... ... ... ...
886 0 2 male 27.0
887 1 1 female 19.0
888 0 3 female NaN
889 1 1 male 26.0
890 0 3 male 32.0

891 rows × 4 columns

# Encode Sex numerically (male -> 1, female -> 0) so the tree can split on it.
# Series.map rebuilds the column in one pass with a numeric dtype, instead of
# leaving an object-dtype column that merely holds Python ints.
data["Sex"] = data["Sex"].map({"male": 1, "female": 0})
data  # notebook cell: display the encoded frame
Survived Pclass Sex Age
0 0 3 1 22.0
1 1 1 0 38.0
2 1 3 0 26.0
3 1 1 0 35.0
4 0 3 1 35.0
... ... ... ... ...
886 0 2 1 27.0
887 1 1 0 19.0
888 0 3 0 NaN
889 1 1 1 26.0
890 0 3 1 32.0

891 rows × 4 columns

# Impute missing ages with the mean age. Only the Age column is filled:
# calling fillna on the whole frame would also write the mean *age* into a
# missing value of any other column.
data["Age"] = data["Age"].fillna(data["Age"].mean())
data  # notebook cell: display the imputed frame
Survived Pclass Sex Age
0 0 3 1 22.000000
1 1 1 0 38.000000
2 1 3 0 26.000000
3 1 1 0 35.000000
4 0 3 1 35.000000
... ... ... ... ...
886 0 2 1 27.000000
887 1 1 0 19.000000
888 0 3 0 29.699118
889 1 1 1 26.000000
890 0 3 1 32.000000

891 rows × 4 columns

# Build a decision tree limited to depth 5, with a fixed random_state so the
# fitted tree is reproducible.
features = data.iloc[:, 1:]  # every column after "Survived" (Pclass, Sex, Age)
target = data["Survived"]
Dtc = DecisionTreeClassifier(max_depth=5, random_state=8)
Dtc.fit(features, target)  # train the model
pre = Dtc.predict(features)  # predict on the same rows
# Per-class precision / recall / F1 report. classification_report takes
# (y_true, y_pred): passing the predictions first swaps precision with recall
# and reports predicted-class counts as "support".
# NOTE: this scores the tree on its own training data, so it measures training
# fit rather than generalization; use a held-out split for an honest estimate.
print(classification_report(target, pre))
              precision    recall  f1-score   support

           0       0.88      0.84      0.86       573
           1       0.73      0.79      0.76       318

    accuracy                           0.82       891
   macro avg       0.81      0.82      0.81       891
weighted avg       0.83      0.82      0.82       891
data["Survived"].eq(pre)  # element-wise: True where the prediction matches the label
0       True
1       True
2       True
3       True
4       True
       ...  
886     True
887     True
888    False
889    False
890     True
Name: Survived, Length: 891, dtype: bool

Visualizing the decision tree

# Export the fitted tree as Graphviz DOT source and render it.
# class_names must be a sequence ordered like the sorted class labels (0, 1).
# A bare string such as "Survive" gets indexed per class, labelling the
# nodes "S" and "u".
dot_data = export_graphviz(
    Dtc,
    feature_names=list(data.columns[1:]),  # same columns the tree was trained on
    class_names=["Died", "Survived"],  # label 0 -> "Died", label 1 -> "Survived"
    filled=True,  # colour nodes by majority class
)
graph = graphviz.Source(dot_data)
graph  # notebook cell: display the rendered tree

(Figure: the rendered decision tree)

Guess you like

Source: blog.csdn.net/qq_45176548/article/details/112060492