Machine Learning: Decision Trees

1. Importing the Libraries

In [1]:
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Make figures interactive and resizable in the notebook
%matplotlib notebook
# Use the SimHei font (only needed if plot text contains Chinese characters)
plt.rc('font', family='SimHei', size=8)

2. Importing the Data

In [2]:
dataset = pd.read_csv('Social_Network_Ads.csv')
dataset
Out[2]:
  User ID Gender Age EstimatedSalary Purchased
0 15624510 Male 19.0 19000.0 0
1 15810944 Male 35.0 20000.0 0
2 15668575 Female 26.0 43000.0 0
3 15603246 Female 27.0 57000.0 0
4 15804002 Male 19.0 76000.0 0
5 15728773 Male 27.0 58000.0 0
6 15598044 Female 27.0 84000.0 0
7 15694829 Female 32.0 150000.0 1
8 15600575 Male 25.0 33000.0 0
9 15727311 Female 35.0 65000.0 0
10 15570769 Female 26.0 80000.0 0
11 15606274 Female 26.0 52000.0 0
12 15746139 Male 20.0 86000.0 0
13 15704987 Male 32.0 18000.0 0
14 15628972 Male 18.0 82000.0 0
15 15697686 Male 29.0 80000.0 0
16 15733883 Male 47.0 25000.0 1
17 15617482 Male 45.0 26000.0 1
18 15704583 Male 46.0 28000.0 1
19 15621083 Female 48.0 29000.0 1
20 15649487 Male 45.0 22000.0 1
21 15736760 Female 47.0 49000.0 1
22 15714658 Male 48.0 41000.0 1
23 15599081 Female 45.0 22000.0 1
24 15705113 Male 46.0 23000.0 1
25 15631159 Male 47.0 20000.0 1
26 15792818 Male 49.0 28000.0 1
27 15633531 Female 47.0 30000.0 1
28 15744529 Male 29.0 43000.0 0
29 15669656 Male 31.0 18000.0 0
... ... ... ... ... ...
370 15611430 Female 60.0 46000.0 1
371 15774744 Male 60.0 83000.0 1
372 15629885 Female 39.0 73000.0 0
373 15708791 Male 59.0 130000.0 1
374 15793890 Female 37.0 80000.0 0
375 15646091 Female 46.0 32000.0 1
376 15596984 Female 46.0 74000.0 0
377 15800215 Female 42.0 53000.0 0
378 15577806 Male 41.0 87000.0 1
379 15749381 Female 58.0 23000.0 1
380 15683758 Male 42.0 64000.0 0
381 15670615 Male 48.0 33000.0 1
382 15715622 Female 44.0 139000.0 1
383 15707634 Male 49.0 28000.0 1
384 15806901 Female 57.0 33000.0 1
385 15775335 Male 56.0 60000.0 1
386 15724150 Female 49.0 39000.0 1
387 15627220 Male 39.0 71000.0 0
388 15672330 Male 47.0 34000.0 1
389 15668521 Female 48.0 35000.0 1
390 15807837 Male 48.0 33000.0 1
391 15592570 Male 47.0 23000.0 1
392 15748589 Female 45.0 45000.0 1
393 15635893 Male 60.0 42000.0 1
394 15757632 Female 39.0 59000.0 0
395 15691863 Female 46.0 41000.0 1
396 15706071 Male 51.0 23000.0 1
397 15654296 Female 50.0 20000.0 1
398 15755018 Male 36.0 33000.0 0
399 15594041 Female 49.0 36000.0 1

400 rows × 5 columns

In [3]:
X = dataset.iloc[:, [2, 3]].values  # Age and EstimatedSalary
y = dataset.iloc[:, 4].values       # Purchased (1 = bought after seeing the ad, 0 = did not)
X
Out[3]:
array([[  1.90000000e+01,   1.90000000e+04],
       [  3.50000000e+01,   2.00000000e+04],
       [  2.60000000e+01,   4.30000000e+04],
       ...,
       [  3.60000000e+01,   3.30000000e+04],
       [  4.90000000e+01,   3.60000000e+04]])
In [4]:
y
Out[4]:
array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1,
       1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
       1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
       0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1,
       0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 0, 1], dtype=int64)

3. Splitting the Dataset into Training and Test Sets

In [5]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, random_state = 0)
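With test_size = 0.25 and 400 rows in total, the split should produce 300 training samples and 100 test samples; a quick sanity check:

print(X_train.shape, X_test.shape)   # expected: (300, 2) (100, 2)
print(y_train.shape, y_test.shape)   # expected: (300,) (100,)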

4. Feature Scaling: a decision tree does not require feature scaling; it is kept here only to make the decision-region plots easier to read (see the sketch after the scaled output below)

In [6]:
from sklearn.preprocessing import StandardScaler  # import the standardization class
sc_X = StandardScaler()
X_train = sc_X.fit_transform(X_train)
X_test = sc_X.transform(X_test)
X_test
Out[6]:
array([[-0.80480212,  0.50496393],
       [-0.01254409, -0.5677824 ],
       [-0.30964085,  0.1570462 ],
       ...,
       [ 0.97777845,  0.59194336],
       [ 0.38358493,  0.99784738]])
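As noted above, a decision tree splits on one feature at a time with a simple threshold, so standardization does not change the splits it can find. A minimal sketch to check this on the unscaled features (it re-creates the split from section 3; the names X_train_raw / X_test_raw and tree_raw are introduced here only for illustration, and the tree used in the rest of the post is trained in the next section):

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same split as section 3, but on the raw (unscaled) Age / EstimatedSalary values
X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
    X, y, test_size=0.25, random_state=0)

tree_raw = DecisionTreeClassifier(criterion='entropy', random_state=0)
tree_raw.fit(X_train_raw, y_train_raw)

# Should give essentially the same test accuracy as the tree trained on scaled data below
print(tree_raw.score(X_test_raw, y_test_raw))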

5. Training the Decision Tree

In [7]:
from sklearn.tree import DecisionTreeClassifier
classifier = DecisionTreeClassifier(criterion='entropy', random_state=0)
classifier.fit(X_train, y_train)
Out[7]:
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=0, splitter='best')
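Setting criterion='entropy' makes each split maximize information gain, i.e. the reduction of the entropy H = -Σ pᵢ·log₂(pᵢ) of the class distribution. A small sketch of the computation, with the helper name entropy introduced only for illustration:

import numpy as np

def entropy(labels):
    """Shannon entropy (base 2) of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

# Entropy of the training labels before any split (at most 1 bit for a binary label)
print(entropy(y_train))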

6. Predicting the Test Set Results

In [8]:
y_pred = classifier.predict(X_test)

7. Confusion Matrix

In [9]:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test,y_pred)
cm
Out[9]:
array([[62,  6],
       [ 3, 29]])
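Rows of the matrix are the true classes and columns the predicted ones: 62 true negatives, 29 true positives, 6 false positives and 3 false negatives, i.e. 9 of the 100 test samples are misclassified. The accuracy follows directly:

# Accuracy = (TN + TP) / total = (62 + 29) / 100 = 0.91
print((cm[0, 0] + cm[1, 1]) / cm.sum())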

8. Visualizing the Predictions

Training set

In [12]:
from matplotlib.colors import ListedColormap  # colormap helper
X_set, y_set = X_train, y_train               # plot the training set
# Build a dense grid covering the (scaled) feature space
X1, X2 = np.meshgrid(np.arange(start = X_set[:, 0].min() - 1, stop = X_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start = X_set[:, 1].min() - 1, stop = X_set[:, 1].max() + 1, step = 0.01))
# Color each grid point by the class the tree predicts (red = 0, green = 1)
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())  # limit the axes to the grid range
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c = ListedColormap(('orange', 'blue'))(i), label = j)
plt.title('Decision Tree (Training set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()

Test set

In [13]:
from matplotlib.colors import ListedColormap
X_set, y_set = X_test, y_test
X1, X2 = np.meshgrid(np.arange(start = X_set[:, 0].min() - 1, stop = X_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start = X_set[:, 1].min() - 1, stop = X_set[:, 1].max() + 1, step = 0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c = ListedColormap(('orange', 'blue'))(i), label = j)
plt.title('Decision Tree (Test set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()

The decision regions fit the training set very well, but on the test set the boundary looks ragged and over-complicated and makes noticeably more mistakes: this is overfitting.
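One way to make the gap concrete is to compare the accuracy of the same classifier on the two sets:

# A fully grown tree usually fits its own training data (almost) perfectly
print('train accuracy:', classifier.score(X_train, y_train))
# Test accuracy is 0.91 here (91 correct out of 100, matching the confusion matrix)
print('test accuracy: ', classifier.score(X_test, y_test))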

9. Optimization: Reducing the Overfitting

In [56]:
from sklearn.tree import DecisionTreeClassifier
classifier = DecisionTreeClassifier(criterion='entropy', random_state=0, min_samples_leaf=8)
classifier.fit(X_train, y_train)

y_pred = classifier.predict(X_test)

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test,y_pred)
cm
Out[56]:
array([[67,  1],
       [ 3, 29]])
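To inspect what the constrained tree actually learned, its split rules can be printed as text; export_text is available in scikit-learn 0.21 and later (older versions can use export_graphviz instead). The feature names below are given only for readability:

from sklearn.tree import export_text

print(export_text(classifier, feature_names=['Age (scaled)', 'EstimatedSalary (scaled)']))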
In [58]:
from matplotlib.colors import ListedColormap
X_set, y_set = X_train, y_train  # re-plot the training set with the pruned tree
X1, X2 = np.meshgrid(np.arange(start = X_set[:, 0].min() - 1, stop = X_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start = X_set[:, 1].min() - 1, stop = X_set[:, 1].max() + 1, step = 0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c = ListedColormap(('orange', 'blue'))(i), label = j)
plt.title('Decision Tree (Training set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()

In [59]:
from matplotlib.colors import ListedColormap
X_set, y_set = X_test, y_test
X1, X2 = np.meshgrid(np.arange(start = X_set[:, 0].min() - 1, stop = X_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start = X_set[:, 1].min() - 1, stop = X_set[:, 1].max() + 1, step = 0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c = ListedColormap(('orange', 'blue'))(i), label = j)
plt.title('Decision Tree (Test set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show()

After setting min_samples_leaf = 8, the result is much better: only 4 test samples (1 false positive and 3 false negatives) are misclassified.
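The value 8 was chosen by hand; a quick sweep over a few candidates (a sketch reusing the scaled split from above) shows how min_samples_leaf trades training accuracy for test accuracy. For a more reliable choice, cross-validation (e.g. GridSearchCV) would be preferable to reading accuracy off a single test split.

from sklearn.tree import DecisionTreeClassifier

for leaf in (1, 2, 4, 8, 16, 32):
    tree = DecisionTreeClassifier(criterion='entropy', random_state=0, min_samples_leaf=leaf)
    tree.fit(X_train, y_train)
    print(leaf, tree.score(X_train, y_train), tree.score(X_test, y_test))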

10. Project Repository


Reposted from blog.csdn.net/u013584315/article/details/80202517