Machine Learning in Practice 1.4, Logistic Regression: Solving Multi-class Problems

Copyright notice: This article is the author's original work; for copyright questions, contact the author. https://blog.csdn.net/weixin_42600072/article/details/88832625

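Main idea: turn the multi-class problem into several binary classification problems (one-vs-rest). For each class, train a binary classifier that separates that class from all the others. To classify a new sample, pick the class whose classifier gives the highest probability.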
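To make the idea concrete, here is a minimal sketch on a tiny made-up 2-D dataset. The points, class labels and variable names are illustrative assumptions, not taken from the article. It fits one binary `LogisticRegression` per class, then assigns each new point to the class whose model reports the highest probability. scikit-learn's `LogisticRegression` can also handle several classes on its own; the explicit loop here just mirrors the approach this article takes.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical toy data: three clusters at the corners of a triangle,
# so each class can be separated from the other two by a straight line
X = np.array([[0, 0], [1, 0], [0, 1],
              [10, 0], [11, 0], [10, 1],
              [0, 10], [1, 10], [0, 11]], dtype=float)
y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
classes = np.unique(y)

# One-vs-rest: for each class, fit a binary model "this class vs. everything else"
models = {}
for c in classes:
    clf = LogisticRegression()
    clf.fit(X, y == c)  # boolean target: True for class c, False otherwise
    models[c] = clf

# Score new points with every binary model and keep P(target == True) from each
X_new = np.array([[0.5, 0.5], [10.5, 0.5], [0.5, 10.5]])
probs = np.column_stack([models[c].predict_proba(X_new)[:, 1] for c in classes])

# The predicted class is the one whose binary model is most confident
predicted = classes[probs.argmax(axis=1)]
print(predicted)
```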

import pandas as pd
import matplotlib.pyplot as plt
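# The data file has no header row, so the column names are supplied manually;
# the fields are separated by whitespace
columns = [
    'mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration',
    'model year', 'origin', 'car name'
]
# sep=r'\s+' replaces delim_whitespace=True, which is deprecated in recent pandas
cars = pd.read_csv('./data/auto-mpg.data', sep=r'\s+', names=columns)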
print(cars.head())
    mpg  cylinders  displacement horsepower  weight  acceleration  model year  \
0  18.0          8         307.0      130.0  3504.0          12.0          70   
1  15.0          8         350.0      165.0  3693.0          11.5          70   
2  18.0          8         318.0      150.0  3436.0          11.0          70   
3  16.0          8         304.0      150.0  3433.0          12.0          70   
4  17.0          8         302.0      140.0  3449.0          10.5          70   

   origin                   car name  
0       1  chevrolet chevelle malibu  
1       1          buick skylark 320  
2       1         plymouth satellite  
3       1              amc rebel sst  
4       1                ford torino  
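In the UCI copy of `auto-mpg.data`, a few `horsepower` entries are recorded as `?`, so pandas reads that column as strings (object dtype) rather than numbers. This article never uses `horsepower` as a feature, so it does no harm here, but if you need the column numerically, one option is `na_values='?'`. A minimal sketch on two sample lines in the same layout (the second line is made up to show a missing value):

```python
import io
import pandas as pd

columns = [
    'mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration',
    'model year', 'origin', 'car name'
]
# Two lines in the auto-mpg.data layout; the second has '?' for horsepower
raw = (
    '18.0   8   307.0   130.0   3504.   12.0   70   1\t"chevrolet chevelle malibu"\n'
    '25.0   4   98.00   ?       2046.   19.0   71   1\t"ford pinto"\n'
)

# na_values='?' turns the placeholder into NaN, so horsepower parses as float
cars = pd.read_csv(io.StringIO(raw), sep=r'\s+', names=columns, na_values='?')
print(cars['horsepower'].dtype)
print(cars['horsepower'].isna().sum())
```

Rows with a missing `horsepower` can then be dropped with `dropna` or filled, depending on the analysis.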

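Use pandas.get_dummies() to one-hot encode the categorical columns cylinders and model year: each distinct value becomes its own indicator column, and the original columns are then dropped.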
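A minimal sketch of what `get_dummies` produces, on a small hypothetical `cylinders` column. Each distinct value gets a column named `<prefix>_<value>`, and every row has a 1 in exactly one of them. Since pandas 2.0 the default output dtype is `bool`, so `dtype=int` is passed here to get the 0/1 columns shown in this article's output.

```python
import pandas as pd

# Hypothetical cylinders column with three distinct values
cylinders = pd.Series([8, 4, 6, 4], name='cylinders')

# One indicator column per distinct value, named cyl_<value>
dummies = pd.get_dummies(cylinders, prefix='cyl', dtype=int)
print(dummies)
```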

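# One-hot encode cylinders and model year: each distinct value becomes an indicator column
# (0/1 in older pandas, True/False by default in pandas 2.x)
dummy_cylinders = pd.get_dummies(cars['cylinders'], prefix='cyl')
#print(dummy_cylinders.head())
cars = pd.concat([cars, dummy_cylinders], axis=1)
#print(cars.head())
dummy_years = pd.get_dummies(cars['model year'], prefix='year')
cars = pd.concat([cars, dummy_years], axis=1)
# The original categorical columns are now redundant, so drop them
cars = cars.drop('model year', axis=1)
cars = cars.drop('cylinders', axis=1)
print(cars.head())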
    mpg  displacement horsepower  weight  acceleration  origin  \
0  18.0         307.0      130.0  3504.0          12.0       1   
1  15.0         350.0      165.0  3693.0          11.5       1   
2  18.0         318.0      150.0  3436.0          11.0       1   
3  16.0         304.0      150.0  3433.0          12.0       1   
4  17.0         302.0      140.0  3449.0          10.5       1   

                    car name  cyl_3  cyl_4  cyl_5   ...     year_73  year_74  \
0  chevrolet chevelle malibu      0      0      0   ...           0        0   
1          buick skylark 320      0      0      0   ...           0        0   
2         plymouth satellite      0      0      0   ...           0        0   
3              amc rebel sst      0      0      0   ...           0        0   
4                ford torino      0      0      0   ...           0        0   

   year_75  year_76  year_77  year_78  year_79  year_80  year_81  year_82  
0        0        0        0        0        0        0        0        0  
1        0        0        0        0        0        0        0        0  
2        0        0        0        0        0        0        0        0  
3        0        0        0        0        0        0        0        0  
4        0        0        0        0        0        0        0        0  

[5 rows x 25 columns]
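import numpy as np
# Shuffle the rows, then use the first 70% for training and the remaining 30% for testing
# (no random seed is set, so the split and all results below change from run to run)
shuffled_rows = np.random.permutation(cars.index)
shuffled_cars = cars.iloc[shuffled_rows]
highest_train_row = int(cars.shape[0] * .70)
train = shuffled_cars.iloc[0:highest_train_row]
test = shuffled_cars.iloc[highest_train_row:]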
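The split above is unseeded, so every run produces a different train/test split and different probabilities. A minimal sketch of a reproducible variant, assuming a seeded `numpy.random.default_rng` is acceptable; `shuffle_split` is a hypothetical helper, not part of the original code. It permutes positions (`len(df)`) rather than index labels, so it also works when the index is not `0..n-1`.

```python
import numpy as np
import pandas as pd

def shuffle_split(df, train_frac=0.7, seed=0):
    """Hypothetical helper: shuffle rows with a seeded generator, then split by position."""
    rng = np.random.default_rng(seed)
    shuffled = df.iloc[rng.permutation(len(df))]
    cut = int(len(df) * train_frac)
    return shuffled.iloc[:cut], shuffled.iloc[cut:]

# Usage on a small hypothetical frame: the same seed always gives the same split
df = pd.DataFrame({'x': range(10)})
train_part, test_part = shuffle_split(df, seed=42)
print(len(train_part), len(test_part))
```

scikit-learn's `train_test_split(..., random_state=...)` is a common alternative that does the same job.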
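from sklearn.linear_model import LogisticRegression
# The distinct values of the target column origin (1 = USA, 2 = Europe, 3 = Japan), sorted
unique_origins = cars['origin'].unique()
unique_origins.sort()

models = {}
# Use only the one-hot columns created above (cyl_* and year_*) as features
features = [c for c in train.columns if c.startswith('cyl') or c.startswith('year')]

# One-vs-rest: train one binary classifier per origin,
# with target True for that origin and False for every other origin
for origin in unique_origins:
    model = LogisticRegression()

    X_train = train[features]
    y_train = train['origin'] == origin

    model.fit(X_train, y_train)
    models[origin] = model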
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.
  FutureWarning)
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.
  FutureWarning)
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.
  FutureWarning)
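These three warnings (one per model) come from scikit-learn 0.20/0.21, where the default solver `liblinear` was announced to change to `lbfgs` in 0.22. Passing `solver` explicitly silences the warning and keeps results from silently changing across versions. A minimal sketch on hypothetical one-hot data; note that `liblinear` also regularizes the intercept, so its probabilities can differ slightly from `lbfgs`.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical one-hot features and a binary "is this origin?" target
X = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [1, 0], [0, 1]])
y = np.array([True, True, False, False, True, False])

# Naming the solver avoids relying on a version-dependent default
# ('liblinear' before scikit-learn 0.22, 'lbfgs' from 0.22 on)
model = LogisticRegression(solver='liblinear')
model.fit(X, y)
probs = model.predict_proba(X)[:, 1]  # P(target == True) for each row
print(probs)
```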
testing_probs = pd.DataFrame(columns=unique_origins)  
print(testing_probs)

for origin in unique_origins:
    # Select testing features.
    X_test = test[features]   
    # Compute probability of observation being in the origin.
    testing_probs[origin] = models[origin].predict_proba(X_test)[:,1]
print(testing_probs)
Empty DataFrame
Columns: [1, 2, 3]
Index: []
            1         2         3
0    0.582297  0.126099  0.306066
1    0.817356  0.067122  0.129694
2    0.272403  0.469606  0.262321
3    0.582297  0.126099  0.306066
4    0.365031  0.282050  0.337488
5    0.321619  0.321747  0.347881
6    0.265345  0.497703  0.247373
7    0.272403  0.469606  0.262321
8    0.836376  0.072543  0.104897
9    0.582297  0.126099  0.306066
10   0.323216  0.487932  0.189434
11   0.821074  0.068240  0.123570
12   0.817356  0.067122  0.129694
13   0.327120  0.325627  0.335424
14   0.316809  0.166039  0.514284
15   0.959333  0.031030  0.026298
16   0.323235  0.412184  0.256623
17   0.582297  0.126099  0.306066
18   0.973135  0.014010  0.037249
19   0.264071  0.355003  0.384353
20   0.323216  0.487932  0.189434
21   0.265345  0.497703  0.247373
22   0.582297  0.126099  0.306066
23   0.323235  0.412184  0.256623
24   0.582297  0.126099  0.306066
25   0.854075  0.086848  0.074716
26   0.316809  0.166039  0.514284
27   0.836376  0.072543  0.104897
28   0.264071  0.355003  0.384353
29   0.582297  0.126099  0.306066
..        ...       ...       ...
90   0.351290  0.340231  0.295533
91   0.818445  0.126277  0.061284
92   0.772058  0.077050  0.148501
93   0.382735  0.385388  0.224242
94   0.959333  0.031030  0.026298
95   0.975043  0.022176  0.021483
96   0.967840  0.024734  0.025550
97   0.958345  0.040700  0.017952
98   0.382735  0.385388  0.224242
99   0.957909  0.034598  0.024356
100  0.321619  0.321747  0.347881
101  0.958345  0.040700  0.017952
102  0.272403  0.469606  0.262321
103  0.316809  0.166039  0.514284
104  0.316809  0.166039  0.514284
105  0.975043  0.022176  0.021483
106  0.582297  0.126099  0.306066
107  0.323235  0.412184  0.256623
108  0.975043  0.022176  0.021483
109  0.814029  0.029313  0.228263
110  0.316809  0.166039  0.514284
111  0.821074  0.068240  0.123570
112  0.959333  0.031030  0.026298
113  0.323235  0.412184  0.256623
114  0.817356  0.067122  0.129694
115  0.382735  0.385388  0.224242
116  0.814029  0.029313  0.228263
117  0.265345  0.497703  0.247373
118  0.323216  0.487932  0.189434
119  0.836376  0.072543  0.104897

[120 rows x 3 columns]
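The table above still holds only probabilities. The usual final one-vs-rest step is to pick, for each test row, the origin column with the highest probability, for example with `testing_probs.idxmax(axis=1)`. The three values in a row need not sum to 1, because the models were trained independently (row 0: 0.582297 + 0.126099 + 0.306066 = 1.014462). Also, `testing_probs` has a fresh `0..n-1` index while `test` keeps the shuffled original index, so compare positionally (`.values`) or build `testing_probs` with `index=test.index`. A minimal sketch using the first three rows of the table and hypothetical true labels:

```python
import numpy as np
import pandas as pd

# First three rows of the probability table above; columns are origins 1, 2, 3
testing_probs = pd.DataFrame(
    [[0.582297, 0.126099, 0.306066],
     [0.817356, 0.067122, 0.129694],
     [0.272403, 0.469606, 0.262321]],
    columns=[1, 2, 3],
)
# Hypothetical true origins for these rows (not taken from the article)
true_origins = np.array([1, 1, 3])

# One-vs-rest decision: the origin whose binary model gives the highest probability
predicted_origins = testing_probs.idxmax(axis=1)

# Compare by position (.values) so mismatched indexes cannot misalign the comparison
accuracy = (predicted_origins.values == true_origins).mean()
print(predicted_origins.tolist(), accuracy)
```

On the real data, `true_origins` would be `test['origin'].values`.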


Reposted from blog.csdn.net/weixin_42600072/article/details/88832625