Implementing a Logistic Regression Model in PySpark

Logistic Regression

Although it is used for classification, it is still called logistic regression. This is because a linear regression equation still operates underneath to find the relationship between the input variables and the target variable. The main distinction between linear and logistic regression is that we apply some sort of nonlinear function to convert the output and restrict it between 0 and 1.

logistic regression = nonlinear function(linear regression)

Linear regression:

$y = b_0 + b_1 x$

Nonlinear function:

$\frac{1}{1 + e^{-x}}$

Logistic regression:

$Probability = \frac{1}{1 + e^{-x}}$

$Probability = \frac{1}{1 + e^{-(b_0 + b_1 x)}}$
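
To make the mapping concrete, here is a minimal NumPy sketch (not from the original post; the coefficient and input values are made up for illustration) showing how the sigmoid turns a linear output into a probability:

import numpy as np

def sigmoid(z):
    # squashes any real number into the open interval (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

b0, b1 = -4.0, 0.5                 # hypothetical intercept and slope
x = np.array([2.0, 8.0, 14.0])     # hypothetical input values

linear_output = b0 + b1 * x        # plain linear regression
probability = sigmoid(linear_output)
print(probability)                 # values strictly between 0 and 1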

Model Evaluation

  • True Positives
    • Actual Class: 1
    • ML Model Prediction Class: 1
  • True Negatives
    • Actual Class: 0
    • ML Model Prediction Class: 0
  • False Positives
    • Actual Class: 0
    • ML Model Prediction Class: 1
  • False Negatives
    • Actual Class: 1
    • ML Model Prediction Class: 0

Accuracy

$\frac{TP + TN}{\text{Total number of records}}$

Recall

$\frac{TP}{TP + FN}$

Precision

$\frac{TP}{TP + FP}$

F1 Score

$F1\ Score = 2 \cdot \frac{Precision \cdot Recall}{Precision + Recall}$
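
All four metrics follow directly from the confusion-matrix counts defined above; a minimal sketch (the helper name and the sample counts are hypothetical):

def classification_metrics(tp, tn, fp, fn):
    # derive the four standard metrics from raw confusion-matrix counts
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, recall, precision, f1

print(classification_metrics(tp=2332, tn=2072, fp=164, fn=168))  # hypothetical counts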

ROC Curve

The ROC curve is used to choose the model's classification threshold, trading off precision against recall.

Building a logistic regression model

  • dataset: 20,000 rows and 6 columns, covering each user's country, age, search platform, number of web pages viewed, and whether the user is a repeat visitor; the goal is to predict whether the customer makes a purchase (the Status column).
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('lg').getOrCreate()

# load data
df = spark.read.csv('./Data/Log_Reg_dataset.csv', inferSchema=True, header=True)

EDA

print((df.count(), len(df.columns)))
(20000, 6)
df.printSchema()
root
 |-- Country: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Repeat_Visitor: integer (nullable = true)
 |-- Platform: string (nullable = true)
 |-- Web_pages_viewed: integer (nullable = true)
 |-- Status: integer (nullable = true)
df.select('Country', 'Platform').show(5)
+---------+--------+
|  Country|Platform|
+---------+--------+
|    India|   Yahoo|
|   Brazil|   Yahoo|
|   Brazil|  Google|
|Indonesia|    Bing|
| Malaysia|  Google|
+---------+--------+
only showing top 5 rows
df.select('Age', 'Repeat_Visitor', 'Web_pages_viewed', 'Status').describe().show()
+-------+-----------------+-----------------+-----------------+------------------+
|summary|              Age|   Repeat_Visitor| Web_pages_viewed|            Status|
+-------+-----------------+-----------------+-----------------+------------------+
|  count|            20000|            20000|            20000|             20000|
|   mean|         28.53955|           0.5029|           9.5533|               0.5|
| stddev|7.888912950773227|0.500004090187782|6.073903499824976|0.5000125004687693|
|    min|               17|                0|                1|                 0|
|    max|              111|                1|               29|                 1|
+-------+-----------------+-----------------+-----------------+------------------+
# record counts per country
df.groupBy('Country').count().show()
+---------+-----+
|  Country|count|
+---------+-----+
| Malaysia| 1218|
|    India| 4018|
|Indonesia|12178|
|   Brazil| 2586|
+---------+-----+
df.groupBy('Platform').count().show()
+--------+-----+
|Platform|count|
+--------+-----+
|   Yahoo| 9859|
|    Bing| 4360|
|  Google| 5781|
+--------+-----+
df.groupBy('Repeat_Visitor').count().show()
+--------------+-----+
|Repeat_Visitor|count|
+--------------+-----+
|             1|10058|
|             0| 9942|
+--------------+-----+
df.groupBy('Status').count().show()
+------+-----+
|Status|count|
+------+-----+
|     1|10000|
|     0|10000|
+------+-----+
df.select('Country', 'Age').groupBy('Country').mean().show()
+---------+------------------+
|  Country|          avg(Age)|
+---------+------------------+
| Malaysia|27.792282430213465|
|    India|27.976854156296664|
|Indonesia| 28.43159796354081|
|   Brazil|30.274168600154677|
+---------+------------------+
df.select('Age','Repeat_Visitor','Web_pages_viewed','Status').groupBy('Status').mean().show()
+------+--------+-------------------+---------------------+-----------+
|Status|avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|avg(Status)|
+------+--------+-------------------+---------------------+-----------+
|     1| 26.5435|             0.7019|              14.5617|        1.0|
|     0| 30.5356|             0.3039|               4.5449|        0.0|
+------+--------+-------------------+---------------------+-----------+
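The gap between the two Status groups in average age and pages viewed already suggests these features carry signal. One more view worth adding (not in the original post) is a crosstab of country against the label, using the DataFrame stats API:

# purchase outcome counts per country
df.crosstab('Country', 'Status').show()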

Feature Engineering

  • Convert the categorical variables into numeric form
  • Assemble all the input features into a single vector column
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import OneHotEncoder
# label encode
platform_indexer = StringIndexer(inputCol='Platform', outputCol='Platform_num').fit(df)

df = platform_indexer.transform(df)

df.show(5)
+---------+---+--------------+--------+----------------+------+------------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|
+---------+---+--------------+--------+----------------+------+------------+
|    India| 41|             1|   Yahoo|              21|     1|         0.0|
|   Brazil| 28|             1|   Yahoo|               5|     0|         0.0|
|   Brazil| 40|             0|  Google|               3|     0|         1.0|
|Indonesia| 31|             1|    Bing|              15|     1|         2.0|
| Malaysia| 32|             0|  Google|              15|     1|         1.0|
+---------+---+--------------+--------+----------------+------+------------+
only showing top 5 rows
# one-hot encode
platform_onehoter = OneHotEncoder(inputCol='Platform_num', outputCol='platform_vector')
df = platform_onehoter.transform(df)
df.show(3)
+-------+---+--------------+--------+----------------+------+------------+---------------+
|Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|platform_vector|
+-------+---+--------------+--------+----------------+------+------------+---------------+
|  India| 41|             1|   Yahoo|              21|     1|         0.0|  (2,[0],[1.0])|
| Brazil| 28|             1|   Yahoo|               5|     0|         0.0|  (2,[0],[1.0])|
| Brazil| 40|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|
+-------+---+--------------+--------+----------------+------+------------+---------------+
only showing top 3 rows
  • (2, [0], [1.0]) : a vector of length 2 with the value 1.0 at index 0, i.e. [1.0, 0.0]
  • (2, [1], [1.0]) : a vector of length 2 with the value 1.0 at index 1, i.e. [0.0, 1.0]
  • (2, [], []) : a vector of length 2 with no nonzero entries, i.e. [0.0, 0.0]
  • This sparse representation saves memory and speeds up computation
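Note that OneHotEncoder is applied here as a plain transformer, which matches the Spark 2.x API; in Spark 3.x it became an estimator, so you would call .fit(df) before .transform(df). To see what the sparse output actually holds, a small sketch (not in the original post) using pyspark.ml.linalg:

from pyspark.ml.linalg import Vectors

# Vectors.sparse(size, indices, values) builds the same structure
# that OneHotEncoder emits
v = Vectors.sparse(2, [0], [1.0])
print(v)            # (2,[0],[1.0])
print(v.toArray())  # [1. 0.]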
# label encode country
country_indexer = StringIndexer(inputCol='Country', outputCol='Country_num').fit(df)
df = country_indexer.transform(df)
# one-hot encode
country_onehoter = OneHotEncoder(inputCol='Country_num', outputCol='Country_vector')
df = country_onehoter.transform(df)
df.select(['Country', 'Country_Num', 'Country_vector']).show(3, False)
+-------+-----------+--------------+
|Country|Country_Num|Country_vector|
+-------+-----------+--------------+
|India  |1.0        |(3,[1],[1.0]) |
|Brazil |2.0        |(3,[2],[1.0]) |
|Brazil |2.0        |(3,[2],[1.0]) |
+-------+-----------+--------------+
only showing top 3 rows
df_assembler = VectorAssembler(
    inputCols=['platform_vector', 'Country_vector', 'Age', 'Repeat_Visitor', 'Web_pages_viewed'],
    outputCol='features')
df = df_assembler.transform(df)
df.show(3)
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
|Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|platform_vector|Country_num|Country_vector|            features|
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
|  India| 41|             1|   Yahoo|              21|     1|         0.0|  (2,[0],[1.0])|        1.0| (3,[1],[1.0])|[1.0,0.0,0.0,1.0,...|
| Brazil| 28|             1|   Yahoo|               5|     0|         0.0|  (2,[0],[1.0])|        2.0| (3,[2],[1.0])|[1.0,0.0,0.0,0.0,...|
| Brazil| 40|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|        2.0| (3,[2],[1.0])|(8,[1,4,5,7],[1.0...|
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
only showing top 3 rows
df.select(['features', 'Status']).show(5, False)
+-----------------------------------+------+
|features                           |Status|
+-----------------------------------+------+
|[1.0,0.0,0.0,1.0,0.0,41.0,1.0,21.0]|1     |
|[1.0,0.0,0.0,0.0,1.0,28.0,1.0,5.0] |0     |
|(8,[1,4,5,7],[1.0,1.0,40.0,3.0])   |0     |
|(8,[2,5,6,7],[1.0,31.0,1.0,15.0])  |1     |
|(8,[1,5,7],[1.0,32.0,15.0])        |1     |
+-----------------------------------+------+
only showing top 5 rows
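Note how the features column mixes dense rows like [1.0,0.0,0.0,1.0,...] with sparse rows like (8,[1,4,5,7],[...]): for each row, Spark keeps whichever representation takes less storage.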

Splitting the dataset

data_set = df.select(['features', 'Status'])
train_df, test_df = data_set.randomSplit([0.75, 0.25])
print(' train_df shape : (%d , %d)' % (train_df.count(), len(train_df.columns)))
print(' test_df shape : (%d , %d)' % (test_df.count(), len(test_df.columns)))
 train_df shape : (15024 , 2)
 test_df shape : (4976 , 2)
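
randomSplit draws a fresh random split on every run; for reproducible experiments (an optional tweak, not in the original post) you can pass a seed:

# fixing the seed makes the 75/25 split reproducible
train_df, test_df = data_set.randomSplit([0.75, 0.25], seed=42)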

Train Logistic Regression Model

from pyspark.ml.classification import LogisticRegression

log_reg = LogisticRegression(labelCol='Status').fit(train_df)

train_pred = log_reg.evaluate(train_df).predictions

train_pred.filter(train_pred['Status'] == 1).filter(train_pred['prediction'] == 1).select(['Status', 'prediction', 'probability']).show(10, False)
+------+----------+----------------------------------------+
|Status|prediction|probability                             |
+------+----------+----------------------------------------+
|1     |1.0       |[0.2936888208146831,0.7063111791853168] |
|1     |1.0       |[0.2936888208146831,0.7063111791853168] |
|1     |1.0       |[0.16371245468320667,0.8362875453167934]|
|1     |1.0       |[0.16371245468320667,0.8362875453167934]|
|1     |1.0       |[0.16371245468320667,0.8362875453167934]|
|1     |1.0       |[0.16371245468320667,0.8362875453167934]|
|1     |1.0       |[0.08438651069737801,0.9156134893026219]|
|1     |1.0       |[0.08438651069737801,0.9156134893026219]|
|1     |1.0       |[0.08438651069737801,0.9156134893026219]|
|1     |1.0       |[0.04158614711493927,0.9584138528850608]|
+------+----------+----------------------------------------+
only showing top 10 rows
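
Before scoring the test set, the fitted model's parameters can be inspected directly; coefficients and intercept are standard attributes of the fitted LogisticRegressionModel (the exact values will vary with the random split):

# one weight per assembled feature, in VectorAssembler input order
print(log_reg.coefficients)
print(log_reg.intercept)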

Evaluate on Test Data

test_result = log_reg.evaluate(test_df).predictions
test_result.show(3)
+--------------------+------+--------------------+--------------------+----------+
|            features|Status|       rawPrediction|         probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|(8,[0,2,5,7],[1.0...|     0|[5.90239719539664...|[0.99727456260972...|       0.0|
|(8,[0,2,5,7],[1.0...|     0|[5.90239719539664...|[0.99727456260972...|       0.0|
|(8,[0,2,5,7],[1.0...|     0|[5.14907138147437...|[0.99422870848783...|       0.0|
+--------------------+------+--------------------+--------------------+----------+
only showing top 3 rows

Accuracy

$Accuracy = \frac{TP + TN}{TP + FP + TN + FN}$

tp = test_result[(test_result.Status == 1) & (test_result.prediction == 1)].count()  # true positives
tn = test_result[(test_result.Status == 0) & (test_result.prediction == 0)].count()  # true negatives
fp = test_result[(test_result.Status == 0) & (test_result.prediction == 1)].count()  # false positives
fn = test_result[(test_result.Status == 1) & (test_result.prediction == 0)].count()  # false negatives
# Accuracy

print('test accuracy is : %f'%((tp+tn)/(tp+tn+fp+fn)))
test accuracy is : 0.885116

Recall and Precision

print('test recall is : %f'%(tp/(tp+fn)))
print('test precision is : %f'%(tp/(tp+fp)))
test recall is : 0.934393
test precision is : 0.935918
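
Spark also ships an evaluator for the area under the ROC curve discussed earlier; a minimal sketch using BinaryClassificationEvaluator (the exact value will vary with the split):

from pyspark.ml.evaluation import BinaryClassificationEvaluator

# areaUnderROC is the default metric; it scores the rawPrediction column
evaluator = BinaryClassificationEvaluator(labelCol='Status')
print('test AUC is : %f' % evaluator.evaluate(test_result))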
