Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u014281392/article/details/89501105
Logistic Regression
Although it is used for classification, it is still called logistic regression. This is because a linear regression equation is still used to relate the input variables to the target variable. The main distinction between linear and logistic regression is that logistic regression passes the linear output through a nonlinear function that restricts it to values between 0 and 1.
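logistic regression = nonlinear function(linear regression)
linear regression : $z = w^{T}x + b$
nonlinear function (sigmoid) : $\sigma(z) = \frac{1}{1 + e^{-z}}$
logistic regression : $P(y = 1 \mid x) = \sigma(w^{T}x + b) = \frac{1}{1 + e^{-(w^{T}x + b)}}$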
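The composition described above can be sketched in a few lines of plain Python. This is only an illustration: the weights, bias, and feature values are hypothetical, and `predict_proba` is a helper name chosen here, not a PySpark function.

```python
import math

def sigmoid(z):
    """Map a real-valued score to a value between 0 and 1 (numerically stable form)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict_proba(weights, bias, features):
    """Logistic regression: the sigmoid applied to a linear-regression score."""
    z = sum(w * x for w, x in zip(weights, features)) + bias
    return sigmoid(z)

# Hypothetical weights for two features, chosen only for illustration.
weights = [0.8, -0.5]
bias = 0.1
print(predict_proba(weights, bias, [1.0, 2.0]))  # linear score is about -0.1, so below 0.5
print(predict_proba(weights, bias, [0.0, 0.0]))  # linear score is 0.1, so above 0.5
```

A score of exactly 0 maps to 0.5, which is why 0.5 is the natural default cut-off between the two classes.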
Model Evaluation
- True Positives
  - Actual Class: 1
  - ML Model Prediction Class: 1
- True Negatives
  - Actual Class: 0
  - ML Model Prediction Class: 0
- False Positives
  - Actual Class: 0
  - ML Model Prediction Class: 1
- False Negatives
  - Actual Class: 1
  - ML Model Prediction Class: 0
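Accuracy : $\frac{TP + TN}{TP + TN + FP + FN}$
Recall : $\frac{TP}{TP + FN}$
Precision : $\frac{TP}{TP + FP}$
F1 Score : $\frac{2 \cdot Precision \cdot Recall}{Precision + Recall}$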
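As a small worked example of these four metrics, the sketch below counts the confusion-matrix cells from (actual, predicted) pairs and derives the metrics from them. The pairs are made up for illustration, and `confusion_counts` and `metrics` are hypothetical helper names.

```python
def confusion_counts(pairs):
    """Count TP, TN, FP, FN from (actual, predicted) pairs with classes 0 and 1."""
    tp = sum(1 for actual, pred in pairs if actual == 1 and pred == 1)
    tn = sum(1 for actual, pred in pairs if actual == 0 and pred == 0)
    fp = sum(1 for actual, pred in pairs if actual == 0 and pred == 1)
    fn = sum(1 for actual, pred in pairs if actual == 1 and pred == 0)
    return tp, tn, fp, fn

def metrics(tp, tn, fp, fn):
    """Return accuracy, precision, recall and F1 from the four counts."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1

# Made-up (actual, predicted) pairs: 3 TP, 3 TN, 1 FP, 1 FN.
pairs = [(1, 1), (1, 1), (1, 0), (0, 0), (0, 0), (0, 1), (1, 1), (0, 0)]
tp, tn, fp, fn = confusion_counts(pairs)
print(tp, tn, fp, fn)
print(metrics(tp, tn, fp, fn))
```

Note that the denominator of accuracy is always the total number of examples, which is a quick sanity check on the four counts.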
ROC Curve
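Used to choose the model's decision threshold, balancing precision against recall. The curve plots the true positive rate (recall) against the false positive rate at each candidate threshold.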
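To make the threshold sweep concrete, here is a minimal sketch that computes ROC points by hand. The labels and scores are hypothetical, and `roc_points` is a helper written for this illustration rather than a library call.

```python
def roc_points(labels, scores):
    """Return (threshold, false positive rate, true positive rate) for each distinct score."""
    positives = sum(labels)
    negatives = len(labels) - positives
    points = []
    # Treat every distinct score as a candidate threshold, from strictest to loosest.
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 0)
        points.append((threshold, fp / negatives, tp / positives))
    return points

# Hypothetical true labels and predicted probabilities of class 1.
labels = [1, 1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2]
for threshold, fpr, tpr in roc_points(labels, scores):
    print(threshold, fpr, tpr)
```

Lowering the threshold can only add predicted positives, so both rates rise together; picking a threshold means picking a point on that curve.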
Building a logistic regression model
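- dataset : 20,000 rows, 6 columns. It contains each user's country, age, search engine used, number of web pages viewed, whether they are a repeat visitor, and so on. The goal is to predict whether the customer makes a purchase.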
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('lg').getOrCreate()
# load data
df = spark.read.csv('./Data/Log_Reg_dataset.csv', inferSchema=True, header=True)
EDA
print((df.count(), len(df.columns)))
(20000, 6)
df.printSchema()
root
|-- Country: string (nullable = true)
|-- Age: integer (nullable = true)
|-- Repeat_Visitor: integer (nullable = true)
|-- Platform: string (nullable = true)
|-- Web_pages_viewed: integer (nullable = true)
|-- Status: integer (nullable = true)
df.select('Country', 'Platform').show(5)
+---------+--------+
| Country|Platform|
+---------+--------+
| India| Yahoo|
| Brazil| Yahoo|
| Brazil| Google|
|Indonesia| Bing|
| Malaysia| Google|
+---------+--------+
only showing top 5 rows
df.select('Age', 'Repeat_Visitor', 'Web_pages_viewed', 'Status').describe().show()
+-------+-----------------+-----------------+-----------------+------------------+
|summary| Age| Repeat_Visitor| Web_pages_viewed| Status|
+-------+-----------------+-----------------+-----------------+------------------+
| count| 20000| 20000| 20000| 20000|
| mean| 28.53955| 0.5029| 9.5533| 0.5|
| stddev|7.888912950773227|0.500004090187782|6.073903499824976|0.5000125004687693|
| min| 17| 0| 1| 0|
| max| 111| 1| 29| 1|
+-------+-----------------+-----------------+-----------------+------------------+
# count country
df.groupBy('Country').count().show()
+---------+-----+
| Country|count|
+---------+-----+
| Malaysia| 1218|
| India| 4018|
|Indonesia|12178|
| Brazil| 2586|
+---------+-----+
df.groupBy('Platform').count().show()
+--------+-----+
|Platform|count|
+--------+-----+
| Yahoo| 9859|
| Bing| 4360|
| Google| 5781|
+--------+-----+
df.groupBy('Repeat_Visitor').count().show()
+--------------+-----+
|Repeat_Visitor|count|
+--------------+-----+
| 1|10058|
| 0| 9942|
+--------------+-----+
df.groupBy('Status').count().show()
+------+-----+
|Status|count|
+------+-----+
| 1|10000|
| 0|10000|
+------+-----+
df.select('Country', 'Age').groupBy('Country').mean().show()
+---------+------------------+
| Country| avg(Age)|
+---------+------------------+
| Malaysia|27.792282430213465|
| India|27.976854156296664|
|Indonesia| 28.43159796354081|
| Brazil|30.274168600154677|
+---------+------------------+
df.select('Age','Repeat_Visitor','Web_pages_viewed','Status').groupBy('Status').mean().show()
+------+--------+-------------------+---------------------+-----------+
|Status|avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|avg(Status)|
+------+--------+-------------------+---------------------+-----------+
| 1| 26.5435| 0.7019| 14.5617| 1.0|
| 0| 30.5356| 0.3039| 4.5449| 0.0|
+------+--------+-------------------+---------------------+-----------+
Feature Engineering
- Convert the categorical variables to numeric values
- Combine the input features into a single column
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import OneHotEncoder
# label encode
platform_indexer = StringIndexer(inputCol='Platform', outputCol='Platform_num').fit(df)
df = platform_indexer.transform(df)
df.show(5)
+---------+---+--------------+--------+----------------+------+------------+
| Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|
+---------+---+--------------+--------+----------------+------+------------+
| India| 41| 1| Yahoo| 21| 1| 0.0|
| Brazil| 28| 1| Yahoo| 5| 0| 0.0|
| Brazil| 40| 0| Google| 3| 0| 1.0|
|Indonesia| 31| 1| Bing| 15| 1| 2.0|
| Malaysia| 32| 0| Google| 15| 1| 1.0|
+---------+---+--------------+--------+----------------+------+------------+
only showing top 5 rows
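The indices above follow from StringIndexer's default ordering, which gives 0.0 to the most frequent label. The sketch below reproduces that mapping in plain Python from the Platform counts shown in the EDA section; `frequency_index` is a hypothetical helper, not part of PySpark.

```python
def frequency_index(counts):
    """Map each label to an index, most frequent label first (ties broken alphabetically)."""
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

# Counts taken from df.groupBy('Platform').count() earlier in the article.
platform_counts = {'Yahoo': 9859, 'Bing': 4360, 'Google': 5781}
print(frequency_index(platform_counts))
```

The result assigns Yahoo 0.0, Google 1.0 and Bing 2.0, matching the Platform_num column.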
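# one-hot encode
# Note: this is the Spark 2.x API, where OneHotEncoder is a transformer.
# In Spark 3.x it is an estimator, so call .fit(df) before .transform(df).
platform_onehoter = OneHotEncoder(inputCol='Platform_num', outputCol='platform_vector')
df = platform_onehoter.transform(df)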
df.show(3)
+-------+---+--------------+--------+----------------+------+------------+---------------+
|Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|platform_vector|
+-------+---+--------------+--------+----------------+------+------------+---------------+
| India| 41| 1| Yahoo| 21| 1| 0.0| (2,[0],[1.0])|
| Brazil| 28| 1| Yahoo| 5| 0| 0.0| (2,[0],[1.0])|
| Brazil| 40| 0| Google| 3| 0| 1.0| (2,[1],[1.0])|
+-------+---+--------------+--------+----------------+------+------------+---------------+
only showing top 3 rows
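- (2,[0],[1.0]) : sparse form (size, indices, values) of the dense vector [1.0, 0.0] (Yahoo)
- (2,[1],[1.0]) : dense vector [0.0, 1.0] (Google)
- (2,[],[]) : dense vector [0.0, 0.0] (Bing; the last category is dropped by default, so it is encoded as all zeros)
- The sparse representation stores only the non-zero entries, which saves memory and speeds up computation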
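To read the sparse notation, it can help to expand it by hand. The sketch below is a plain-Python illustration with a hypothetical helper `to_dense`; within PySpark, `SparseVector.toArray()` serves the same purpose.

```python
def to_dense(size, indices, values):
    """Expand a (size, indices, values) sparse triple into a dense list."""
    dense = [0.0] * size
    for index, value in zip(indices, values):
        dense[index] = value
    return dense

print(to_dense(2, [0], [1.0]))  # Platform_num 0.0
print(to_dense(2, [1], [1.0]))  # Platform_num 1.0
print(to_dense(2, [], []))      # Platform_num 2.0, the dropped last category
```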
# label encode country
country_indexer = StringIndexer(inputCol='Country', outputCol='Country_num').fit(df)
df = country_indexer.transform(df)
# one-hot encode
country_onehoter = OneHotEncoder(inputCol='Country_num', outputCol='Country_vector')
df = country_onehoter.transform(df)
df.select(['Country', 'Country_Num', 'Country_vector']).show(3, False)
+-------+-----------+--------------+
|Country|Country_Num|Country_vector|
+-------+-----------+--------------+
|India |1.0 |(3,[1],[1.0]) |
|Brazil |2.0 |(3,[2],[1.0]) |
|Brazil |2.0 |(3,[2],[1.0]) |
+-------+-----------+--------------+
only showing top 3 rows
df_assembler = VectorAssembler(inputCols=['platform_vector', 'Country_vector', 'Age', 'Repeat_Visitor', 'Web_pages_viewed'],
outputCol='features')
df = df_assembler.transform(df)
df.show(3)
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
|Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_num|platform_vector|Country_num|Country_vector| features|
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
| India| 41| 1| Yahoo| 21| 1| 0.0| (2,[0],[1.0])| 1.0| (3,[1],[1.0])|[1.0,0.0,0.0,1.0,...|
| Brazil| 28| 1| Yahoo| 5| 0| 0.0| (2,[0],[1.0])| 2.0| (3,[2],[1.0])|[1.0,0.0,0.0,0.0,...|
| Brazil| 40| 0| Google| 3| 0| 1.0| (2,[1],[1.0])| 2.0| (3,[2],[1.0])|(8,[1,4,5,7],[1.0...|
+-------+---+--------------+--------+----------------+------+------------+---------------+-----------+--------------+--------------------+
only showing top 3 rows
df.select(['features', 'Status']).show(5, False)
+-----------------------------------+------+
|features |Status|
+-----------------------------------+------+
|[1.0,0.0,0.0,1.0,0.0,41.0,1.0,21.0]|1 |
|[1.0,0.0,0.0,0.0,1.0,28.0,1.0,5.0] |0 |
|(8,[1,4,5,7],[1.0,1.0,40.0,3.0]) |0 |
|(8,[2,5,6,7],[1.0,31.0,1.0,15.0]) |1 |
|(8,[1,5,7],[1.0,32.0,15.0]) |1 |
+-----------------------------------+------+
only showing top 5 rows
Splitting the dataset
data_set = df.select(['features', 'Status'])
train_df, test_df = data_set.randomSplit([0.75, 0.25])
print(' train_df shape : (%d , %d)'%(train_df.count(), len(train_df.columns)))
print(' test_df shape: :(%d , %d)'%(test_df.count(), len(test_df.columns)))
train_df shape : (15024 , 2)
test_df shape: :(4976 , 2)
Train Logistic Regression Model
from pyspark.ml.classification import LogisticRegression
log_reg = LogisticRegression(labelCol = 'Status').fit(train_df)
train_pred = log_reg.evaluate(train_df).predictions
train_pred.filter(train_pred['Status'] == 1).filter(train_pred['prediction'] == 1).select(['Status', 'prediction', 'probability']).show(10, False)
+------+----------+----------------------------------------+
|Status|prediction|probability |
+------+----------+----------------------------------------+
|1 |1.0 |[0.2936888208146831,0.7063111791853168] |
|1 |1.0 |[0.2936888208146831,0.7063111791853168] |
|1 |1.0 |[0.16371245468320667,0.8362875453167934]|
|1 |1.0 |[0.16371245468320667,0.8362875453167934]|
|1 |1.0 |[0.16371245468320667,0.8362875453167934]|
|1 |1.0 |[0.16371245468320667,0.8362875453167934]|
|1 |1.0 |[0.08438651069737801,0.9156134893026219]|
|1 |1.0 |[0.08438651069737801,0.9156134893026219]|
|1 |1.0 |[0.08438651069737801,0.9156134893026219]|
|1 |1.0 |[0.04158614711493927,0.9584138528850608]|
+------+----------+----------------------------------------+
only showing top 10 rows
Evaluate on test data
test_result = log_reg.evaluate(test_df).predictions
test_result.show(3)
+--------------------+------+--------------------+--------------------+----------+
| features|Status| rawPrediction| probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|(8,[0,2,5,7],[1.0...| 0|[5.90239719539664...|[0.99727456260972...| 0.0|
|(8,[0,2,5,7],[1.0...| 0|[5.90239719539664...|[0.99727456260972...| 0.0|
|(8,[0,2,5,7],[1.0...| 0|[5.14907138147437...|[0.99422870848783...| 0.0|
+--------------------+------+--------------------+--------------------+----------+
only showing top 3 rows
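The rawPrediction and probability columns are linked by the sigmoid function, and prediction comes from comparing the class 1 probability with the default threshold of 0.5. The sketch below checks this against the first row above in plain Python; the rawPrediction value is truncated in the output, so only the displayed digits are used.

```python
import math

def sigmoid(z):
    """Logistic function; fine for the moderate scores used here."""
    return 1.0 / (1.0 + math.exp(-z))

# First element of rawPrediction in the first test row (displayed digits only).
raw_class0 = 5.90239719539664
prob_class0 = sigmoid(raw_class0)   # compare with probability[0] in the same row
prob_class1 = 1.0 - prob_class0
prediction = 1.0 if prob_class1 > 0.5 else 0.0
print(prob_class0, prob_class1, prediction)
```

The computed class 0 probability agrees with the displayed 0.99727456..., and the class 1 probability is far below 0.5, so the row is predicted as 0.0.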
Accuracy
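tp = test_result[(test_result.Status == 1) & (test_result.prediction == 1)].count()
tn = test_result[(test_result.Status == 0) & (test_result.prediction == 0)].count()
fp = test_result[(test_result.Status == 0) & (test_result.prediction == 1)].count()
fn = test_result[(test_result.Status == 1) & (test_result.prediction == 0)].count()
# Accuracy
# The figure printed below came from a run in which tn was filtered on prediction == 1,
# so it counted false positives instead of true negatives.
# With tn filtered on prediction == 0 as above, the printed recall, precision and
# test-set size imply an accuracy of about 0.936.
print('test accuracy is : %f'%((tp+tn)/(tp+tn+fp+fn)))
test accuracy is : 0.885116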
Recall and Precision
print('test recall is : %f'%(tp/(tp+fn)))
print('test precision is : %f'%(tp/(tp+fp)))
test recall is : 0.934393
test precision is : 0.935918
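The F1 score listed under Model Evaluation is not computed above. As a minimal sketch, it can be derived from the printed recall and precision as their harmonic mean; the two values are copied from the output above, so the result is only as precise as those six decimals.

```python
# Values copied from the printed test output above.
recall = 0.934393
precision = 0.935918

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print('test f1 is : %f' % f1)
```

Because precision and recall are nearly equal here, F1 lands between them at roughly 0.935.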