人工智能AI课 推荐算法实战个性化电商广告推荐系统

个性化电商广告推荐系统介绍

1.1 数据集介绍

  • Ali_Display_Ad_Click是阿里巴巴提供的一个淘宝展示广告点击率预估数据集

    数据集来源:天池竞赛

  • 原始样本骨架raw_sample

    淘宝网站中随机抽样了114万用户8天内的广告展示/点击日志(2600万条记录),构成原始的样本骨架。 字段说明如下:

    1. user_id:脱敏过的用户ID;
    2. adgroup_id:脱敏过的广告单元ID;
    3. time_stamp:时间戳;
    4. pid:资源位;
    5. noclk:为1代表没有点击;为0代表点击;
    6. clk:为0代表没有点击;为1代表点击;

    用前面7天的做训练样本(20170506-20170512),用第8天的做测试样本(20170513)

  • 广告基本信息表ad_feature

    本数据集涵盖了raw_sample中全部广告的基本信息(约80万条目)。字段说明如下:

    1. adgroup_id:脱敏过的广告ID;
    2. cate_id:脱敏过的商品类目ID;
    3. campaign_id:脱敏过的广告计划ID;
    4. customer_id: 脱敏过的广告主ID;
    5. brand_id:脱敏过的品牌ID;
    6. price: 宝贝的价格

    其中一个广告ID对应一个商品(宝贝),一个宝贝属于一个类目,一个宝贝属于一个品牌。

  • 用户基本信息表user_profile

    本数据集涵盖了raw_sample中全部用户的基本信息(约100多万用户)。字段说明如下:

    1. userid:脱敏过的用户ID;
    2. cms_segid:微群ID;
    3. cms_group_id:cms_group_id;
    4. final_gender_code:性别 1:男,2:女;
    5. age_level:年龄层次; 1234
    6. pvalue_level:消费档次,1:低档,2:中档,3:高档;
    7. shopping_level:购物深度,1:浅层用户,2:中度用户,3:深度用户
    8. occupation:是否大学生 ,1:是,0:否
    9. new_user_class_level:城市层级
  • 用户的行为日志behavior_log

    本数据集涵盖了raw_sample中全部用户22天内的购物行为(共七亿条记录)。字段说明如下:

    user:脱敏过的用户ID;
    time_stamp:时间戳;
    btag:行为类型, 包括以下四种:
    ​ 类型 | 说明
    ​ pv | 浏览
    ​ cart | 加入购物车
    ​ fav | 喜欢
    ​ buy | 购买
    cate_id:脱敏过的商品类目id;
    brand_id: 脱敏过的品牌id;
    这里以user + time_stamp为key,会有很多重复的记录;这是因为我们的不同的类型的行为数据是不同部门记录的,在打包到一起的时候,实际上会有小的偏差(即两个一样的time_stamp实际上是差异比较小的两个时间)

1.2 项目效果展示

1.3 项目实现分析

  • 主要包括

    • 一份广告点击的样本数据raw_sample.csv:体现的是用户对不同位置广告点击、没点击的情况
    • 一份广告基本信息数据ad_feature.csv:体现的是每个广告的类目(id)、品牌(id)、价格特征
    • 一份用户基本信息数据user_profile.csv:体现的是用户群组、性别、年龄、消费购物档次、所在城市级别等特征
    • 一份用户行为日志数据behavior_log.csv:体现用户对商品类目(id)、品牌(id)的浏览、加购物车、收藏、购买等信息

    我们是在对非搜索类型的广告进行点击率预测和推荐(没有搜索词、没有广告的内容特征信息)

    1. 推荐业务处理主要流程: 召回 ===> 排序 ===> 过滤
      • 离线处理业务流
        • raw_sample.csv ==> 历史样本数据
        • ad_feature.csv ==> 广告特征数据
        • user_profile.csv ==> 用户特征数据
        • raw_sample.csv + ad_feature.csv + user_profile.csv ==> CTR点击率预测模型
        • behavior_log.csv ==> 评分数据 ==> user-cate/brand评分数据 ==> 协同过滤 ==> top-N cate/brand ==> 关联广告
        • 协同过滤召回 ==> top-N cate/brand ==> 关联对应的广告完成召回
      • 在线处理业务流
        • 数据处理部分:
          • 实时行为日志 ==> 实时特征 ==> 缓存
          • 实时行为日志 ==> 实时商品类别/品牌 ==> 实时广告召回集 ==> 缓存
        • 推荐任务部分:
          • CTR点击率预测模型 + 广告/用户特征(缓存) + 对应的召回集(缓存) ==> 点击率排序 ==> top-N 广告推荐结果
    2. 涉及技术:Flume、Kafka、Spark-streming\HDFS、Spark SQL、Spark ML、Redis
      • Flume:日志数据收集
      • Kafka:实时日志数据处理队列
      • HDFS:存储数据
      • Spark SQL:离线处理
      • Spark ML:模型训练
      • Redis:缓存

1.4 点击率预测(CTR–Click-Through-Rate)概念

  • 电商广告推荐通常使用广告点击率(CTR–Click-Through-Rate)预测来实现

    点击率预测 VS 推荐算法

    点击率预测需要给出精准的点击概率,比如广告A点击率0.5%、广告B的点击率0.12%等;而推荐算法很多时候只需要得出一个最优的次序A>B>C即可。

    点击率预测使用的算法通常是如逻辑回归(Logic Regression)这样的机器学习算法,而推荐算法则是一些基于协同过滤推荐、基于内容的推荐等思想实现的算法

    点击率 VS 转化率

    点击率预测是对每次广告的点击情况做出预测,可以判定这次为点击或不点击,也可以给出点击或不点击的概率

    转化率指的是从状态A进入到状态B的概率,电商的转化率通常是指到达网站后,进而有成交记录的用户比率,如用户成交量/用户访问量

    搜索和非搜索广告点击率预测的区别

    搜索中有很强的搜索信号-“查询词(Query)”,查询词和广告内容的匹配程度很大程度影响了点击概率,搜索广告的点击率普遍较高

    非搜索广告(例如展示广告,信息流广告)的点击率的计算很多就来源于用户的兴趣和广告自身的特征,以及上下文环境。通常好位置能达到百分之几的点击率。对于很多底部的广告,点击率非常低,常常是千分之几,甚至更低

二 根据用户行为数据创建ALS模型并召回商品

2.0 用户行为数据拆分

  • 方便练习可以对数据做拆分处理

    • pandas的数据分批读取 chunk 厚厚的一块 相当大的数量或部分
    import pandas as pd
    reader = pd.read_csv('behavior_log.csv',chunksize=100,iterator=True)
    count = 0;
    for chunk in reader:
        count += 1
        if count ==1:
            chunk.to_csv('test4.csv',index = False)
        elif count>1 and count<1000:
            chunk.to_csv('test4.csv',index = False, mode = 'a',header = False)
        else:
            break
    pd.read_csv('test4.csv')
    

2.1 预处理behavior_log数据集

  • 创建spark session
import os
# 配置spark driver和pyspark运行时,所使用的python解释器路径
PYSPARK_PYTHON = "/home/hadoop/miniconda3/envs/datapy365spark23/bin/python"
JAVA_HOME='/home/hadoop/app/jdk1.8.0_191'
# 当存在多个版本时,不指定很可能会导致出错
os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
os.environ['JAVA_HOME']=JAVA_HOME
# spark配置信息
from pyspark import SparkConf
from pyspark.sql import SparkSession

SPARK_APP_NAME = "preprocessingBehaviorLog"
SPARK_URL = "spark://192.168.199.188:7077"

conf = SparkConf()    # 创建spark config对象
config = (
	("spark.app.name", SPARK_APP_NAME),    # 设置启动的spark的app名称,没有提供,将随机产生一个名称
	("spark.executor.memory", "6g"),    # 设置该app启动时占用的内存用量,默认1g
	("spark.master", SPARK_URL),    # spark master的地址
    ("spark.executor.cores", "4"),    # 设置spark executor使用的CPU核心数
    # 以下三项配置,可以控制执行器数量
#     ("spark.dynamicAllocation.enabled", True),
#     ("spark.dynamicAllocation.initialExecutors", 1),    # 1个执行器
#     ("spark.shuffle.service.enabled", True)
# 	('spark.sql.pivotMaxValues', '99999'),  # 当需要pivot DF,且值很多时,需要修改,默认是10000
)
# 查看更详细配置及说明:https://spark.apache.org/docs/latest/configuration.html

conf.setAll(config)

# 利用config对象,创建spark session
spark = SparkSession.builder.config(conf=conf).getOrCreate()
  • 从hdfs中加载csv文件为DataFrame
# 从hdfs加载CSV文件为DataFrame
df = spark.read.csv("hdfs://localhost:9000/datasets/behavior_log.csv", header=True)
df.show()    # 查看dataframe,默认显示前20条
# 大致查看一下数据类型
df.printSchema()    # 打印当前dataframe的结构

显示结果:

+------+----------+----+-----+------+
|  user|time_stamp|btag| cate| brand|
+------+----------+----+-----+------+
|558157|1493741625|  pv| 6250| 91286|
|558157|1493741626|  pv| 6250| 91286|
|558157|1493741627|  pv| 6250| 91286|
|728690|1493776998|  pv|11800| 62353|
|332634|1493809895|  pv| 1101|365477|
|857237|1493816945|  pv| 1043|110616|
|619381|1493774638|  pv|  385|428950|
|467042|1493772641|  pv| 8237|301299|
|467042|1493772644|  pv| 8237|301299|
|991528|1493780710|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780714|  pv| 7270|274795|
|991528|1493780765|  pv| 7270|274795|
|991528|1493780714|  pv| 7270|274795|
|991528|1493780765|  pv| 7270|274795|
|991528|1493780764|  pv| 7270|274795|
|991528|1493780633|  pv| 7270|274795|
|991528|1493780764|  pv| 7270|274795|
+------+----------+----+-----+------+
only showing top 20 rows

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- btag: string (nullable = true)
 |-- cate: string (nullable = true)
 |-- brand: string (nullable = true)
  • 从hdfs加载数据为dataframe,并设置结构
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType
# 构建结构对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("timestamp", LongType()),
    StructField("btag", StringType()),
    StructField("cateId", IntegerType()),
    StructField("brandId", IntegerType())
])
# 从hdfs加载数据为dataframe,并设置结构
behavior_log_df = spark.read.csv("hdfs://localhost:8020/datasets/behavior_log.csv", header=True, schema=schema)
behavior_log_df.show()
behavior_log_df.count() 

显示结果:

+------+----------+----+------+-------+
|userId| timestamp|btag|cateId|brandId|
+------+----------+----+------+-------+
|558157|1493741625|  pv|  6250|  91286|
|558157|1493741626|  pv|  6250|  91286|
|558157|1493741627|  pv|  6250|  91286|
|728690|1493776998|  pv| 11800|  62353|
|332634|1493809895|  pv|  1101| 365477|
|857237|1493816945|  pv|  1043| 110616|
|619381|1493774638|  pv|   385| 428950|
|467042|1493772641|  pv|  8237| 301299|
|467042|1493772644|  pv|  8237| 301299|
|991528|1493780710|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780714|  pv|  7270| 274795|
|991528|1493780765|  pv|  7270| 274795|
|991528|1493780714|  pv|  7270| 274795|
|991528|1493780765|  pv|  7270| 274795|
|991528|1493780764|  pv|  7270| 274795|
|991528|1493780633|  pv|  7270| 274795|
|991528|1493780764|  pv|  7270| 274795|
+------+----------+----+------+-------+
only showing top 20 rows

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- btag: string (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
  • 分析数据集字段的类型和格式
    • 查看是否有空值
    • 查看每列数据的类型
    • 查看每列数据的类别情况
print("查看userId的数据情况:", behavior_log_df.groupBy("userId").count().count())
# 约113w用户
#注意:behavior_log_df.groupBy("userId").count()  返回的是一个dataframe,这里的count计算的是每一个分组的个数,但当前还没有进行计算
# 当调用df.count()时才开始进行计算,这里的count计算的是dataframe的条目数,也就是共有多少个分组
查看user的数据情况: 1136340
print("查看btag的数据情况:", behavior_log_df.groupBy("btag").count().collect())    # collect会把计算结果全部加载到内存,谨慎使用
# 只有四种类型数据:pv、fav、cart、buy
# 这里由于类型只有四个,所以直接使用collect,把数据全部加载出来
查看btag的数据情况: [Row(btag='buy', count=9115919), Row(btag='fav', count=9301837), Row(btag='cart', count=15946033), Row(btag='pv', count=688904345)]
print("查看cateId的数据情况:", behavior_log_df.groupBy("cateId").count().count())
# 约12968类别id
查看cateId的数据情况: 12968
print("查看brandId的数据情况:", behavior_log_df.groupBy("brandId").count().count())
# 约460561品牌id
查看brandId的数据情况: 460561
print("判断数据是否有空值:", behavior_log_df.count(), behavior_log_df.dropna().count())
# 约7亿条目723268134 723268134
# 本数据集无空值条目,可放心处理
判断数据是否有空值: 723268134 723268134
  • pivot透视操作,把某列里的字段值转换成行并进行聚合运算(pyspark.sql.GroupedData.pivot)
    • 如果透视的字段中的不同属性值超过10000个,则需要设置spark.sql.pivotMaxValues,否则计算过程中会出现错误。文档介绍
# 统计每个用户对各类商品的pv、fav、cart、buy数量
cate_count_df = behavior_log_df.groupBy(behavior_log_df.userId, behavior_log_df.cateId).pivot("btag",["pv","fav","cart","buy"]).count()
cate_count_df.printSchema()    # 此时还没有开始计算

显示效果:

root
 |-- userId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- pv: long (nullable = true)
 |-- fav: long (nullable = true)
 |-- cart: long (nullable = true)
 |-- buy: long (nullable = true)
  • 统计每个用户对各个品牌的pv、fav、cart、buy数量并保存结果
# 统计每个用户对各个品牌的pv、fav、cart、buy数量
brand_count_df = behavior_log_df.groupBy(behavior_log_df.userId, behavior_log_df.brandId).pivot("btag",["pv","fav","cart","buy"]).count()
# brand_count_df.show()    # 同上
# 113w * 46w
# 由于运算时间比较长,所以这里先将结果存储起来,供后续其他操作使用
# 写入数据时才开始计算
cate_count_df.write.csv("hdfs://localhost:9000/preprocessing_dataset/cate_count.csv", header=True)
brand_count_df.write.csv("hdfs://localhost:9000/preprocessing_dataset/brand_count.csv", header=True)

2.2 根据用户对类目偏好打分训练ALS模型

  • 根据您统计的次数 + 打分规则 ==> 偏好打分数据集 ==> ALS模型
# spark ml的模型训练是基于内存的,如果数据过大,内存空间小,迭代次数过多的化,可能会造成内存溢出,报错
# 设置Checkpoint的话,会把所有数据落盘,这样如果异常退出,下次重启后,可以接着上次的训练节点继续运行
# 但该方法其实指标不治本,因为无法防止内存溢出,所以还是会报错
# 如果数据量大,应考虑的是增加内存、或限制迭代次数和训练数据量级等
spark.sparkContext.setCheckpointDir("hdfs://localhost:8020/checkPoint/")
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建结构对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cateId", IntegerType()),
    StructField("pv", IntegerType()),
    StructField("fav", IntegerType()),
    StructField("cart", IntegerType()),
    StructField("buy", IntegerType())
])

# 从hdfs加载CSV文件
cate_count_df = spark.read.csv("hdfs://localhost:9000/preprocessing_dataset/cate_count.csv", header=True, schema=schema)
cate_count_df.printSchema()
cate_count_df.first()    # 第一行数据

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- pv: integer (nullable = true)
 |-- fav: integer (nullable = true)
 |-- cart: integer (nullable = true)
 |-- buy: integer (nullable = true)

Row(userId=1061650, cateId=4520, pv=2326, fav=None, cart=53, buy=None)
  • 处理每一行数据:r表示row对象
def process_row(r):
    # 处理每一行数据:r表示row对象
    
    # 偏好评分规则:
	#     m: 用户对应的行为次数
    #     该偏好权重比例,次数上限仅供参考,具体数值应根据产品业务场景权衡
	#     pv: if m<=20: score=0.2*m; else score=4
	#     fav: if m<=20: score=0.4*m; else score=8
	#     cart: if m<=20: score=0.6*m; else score=12
	#     buy: if m<=20: score=1*m; else score=20
    
    # 注意这里要全部设为浮点数,spark运算时对类型比较敏感,要保持数据类型都一致
	pv_count = r.pv if r.pv else 0.0
	fav_count = r.fav if r.fav else 0.0
	cart_count = r.cart if r.cart else 0.0
	buy_count = r.buy if r.buy else 0.0

	pv_score = 0.2*pv_count if pv_count<=20 else 4.0
	fav_score = 0.4*fav_count if fav_count<=20 else 8.0
	cart_score = 0.6*cart_count if cart_count<=20 else 12.0
	buy_score = 1.0*buy_count if buy_count<=20 else 20.0

	rating = pv_score + fav_score + cart_score + buy_score
	# 返回用户ID、分类ID、用户对分类的偏好打分
	return r.userId, r.cateId, rating
  • 返回一个PythonRDD类型
# 返回一个PythonRDD类型,此时还没开始计算
cate_count_df.rdd.map(process_row).toDF(["userId", "cateId", "rating"])

显示结果:

DataFrame[userId: bigint, cateId: bigint, rating: double]
  • 用户对商品类别的打分数据
# 用户对商品类别的打分数据
# map返回的结果是rdd类型,需要调用toDF方法转换为Dataframe
cate_rating_df = cate_count_df.rdd.map(process_row).toDF(["userId", "cateId", "rating"])
# 注意:toDF不是每个rdd都有的方法,仅局限于此处的rdd
# 可通过该方法获得 user-cate-matrix
# 但由于cateId字段过多,这里运算量比很大,机器内存要求很高才能执行,否则无法完成任务
# 请谨慎使用

# 但好在我们训练ALS模型时,不需要转换为user-cate-matrix,所以这里可以不用运行
# cate_rating_df.groupBy("userId").povit("cateId").min("rating")
# 用户对类别的偏好打分数据
cate_rating_df

显示结果:

DataFrame[userId: bigint, cateId: bigint, rating: double]
  • 通常如果USER-ITEM打分数据应该是通过一下方式进行处理转换为USER-ITEM-MATRIX

但这里我们将使用的Spark的ALS模型进行CF推荐,因此注意这里数据输入不需要提前转换为矩阵,直接是 USER-ITEM-RATE的数据

  • 基于Spark的ALS隐因子模型进行CF评分预测

    • ALS的意思是交替最小二乘法(Alternating Least Squares),是Spark2.*中加入的进行基于模型的协同过滤(model-based CF)的推荐系统算法。

      同SVD,它也是一种矩阵分解技术,对数据进行降维处理。

    • 详细使用方法:pyspark.ml.recommendation.ALS

    • 注意:由于数据量巨大,因此这里也不考虑基于内存的CF算法

      参考:为什么Spark中只有ALS

# 使用pyspark中的ALS矩阵分解方法实现CF评分预测
# 文档地址:https://spark.apache.org/docs/2.2.2/api/python/pyspark.ml.html?highlight=vectors#module-pyspark.ml.recommendation
from pyspark.ml.recommendation import ALS   # ml:dataframe, mllib:rdd

# 利用打分数据,训练ALS模型
als = ALS(userCol='userId', itemCol='cateId', ratingCol='rating', checkpointInterval=5)

# 此处训练时间较长
model = als.fit(cate_rating_df)
# model.recommendForAllUsers(N) 给所有用户推荐TOP-N个物品
ret = model.recommendForAllUsers(3)
# 由于是给所有用户进行推荐,此处运算时间也较长
ret.show()
# 推荐结果存放在recommendations列中,
ret.select("recommendations").show()

显示结果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   148|[[3347, 12.547271...|
|   463|[[1610, 9.250818]...|
|   471|[[1610, 10.246621...|
|   496|[[1610, 5.162216]...|
|   833|[[5607, 9.065482]...|
|  1088|[[104, 6.886987],...|
|  1238|[[5631, 14.51981]...|
|  1342|[[5720, 10.89842]...|
|  1580|[[5731, 8.466453]...|
|  1591|[[1610, 12.835257...|
|  1645|[[1610, 11.968531...|
|  1829|[[1610, 17.576496...|
|  1959|[[1610, 8.353473]...|
|  2122|[[1610, 12.652732...|
|  2142|[[1610, 12.48068]...|
|  2366|[[1610, 11.904813...|
|  2659|[[5607, 11.699315...|
|  2866|[[1610, 7.752719]...|
|  3175|[[3347, 2.3429515...|
|  3749|[[1610, 3.641833]...|
+------+--------------------+
only showing top 20 rows

+--------------------+
|     recommendations|
+--------------------+
|[[3347, 12.547271...|
|[[1610, 9.250818]...|
|[[1610, 10.246621...|
|[[1610, 5.162216]...|
|[[5607, 9.065482]...|
|[[104, 6.886987],...|
|[[5631, 14.51981]...|
|[[5720, 10.89842]...|
|[[5731, 8.466453]...|
|[[1610, 12.835257...|
|[[1610, 11.968531...|
|[[1610, 17.576496...|
|[[1610, 8.353473]...|
|[[1610, 12.652732...|
|[[1610, 12.48068]...|
|[[1610, 11.904813...|
|[[5607, 11.699315...|
|[[1610, 7.752719]...|
|[[3347, 2.3429515...|
|[[1610, 3.641833]...|
+--------------------+
only showing top 20 rows
  • model.recommendForUserSubset 给部分用户推荐TOP-N个物品
# 注意:recommendForUserSubset API,2.2.2版本中无法使用
dataset = spark.createDataFrame([[1],[2],[3]])
dataset = dataset.withColumnRenamed("_1", "userId")
ret = model.recommendForUserSubset(dataset, 3)

# 只给部分用推荐,运算时间短
ret.show()
ret.collect()    # 注意: collect会将所有数据加载到内存,慎用

显示结果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|     1|[[1610, 25.4989],...|
|     3|[[5607, 13.665942...|
|     2|[[5579, 5.9051886...|
+------+--------------------+

[Row(userId=1, recommendations=[Row(cateId=1610, rating=25.498899459838867), Row(cateId=5737, rating=24.901548385620117), Row(cateId=3347, rating=20.736785888671875)]),
 Row(userId=3, recommendations=[Row(cateId=5607, rating=13.665942192077637), Row(cateId=1610, rating=11.770171165466309), Row(cateId=3347, rating=10.35690689086914)]),
 Row(userId=2, recommendations=[Row(cateId=5579, rating=5.90518856048584), Row(cateId=2447, rating=5.624575138092041), Row(cateId=5690, rating=5.2555742263793945)])]
  • transform中提供userId和cateId可以对打分进行预测,利用打分结果排序后
# transform中提供userId和cateId可以对打分进行预测,利用打分结果排序后,同样可以实现TOP-N的推荐
model.transform
# 将模型进行存储
model.save("hdfs://localhost:8020/models/userCateRatingALSModel.obj")
# 测试存储的模型
from pyspark.ml.recommendation import ALSModel
# 从hdfs加载之前存储的模型
als_model = ALSModel.load("hdfs://localhost:8020/models/userCateRatingALSModel.obj")
# model.recommendForAllUsers(N) 给用户推荐TOP-N个物品
result = als_model.recommendForAllUsers(3)
result.show()

显示结果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   148|[[3347, 12.547271...|
|   463|[[1610, 9.250818]...|
|   471|[[1610, 10.246621...|
|   496|[[1610, 5.162216]...|
|   833|[[5607, 9.065482]...|
|  1088|[[104, 6.886987],...|
|  1238|[[5631, 14.51981]...|
|  1342|[[5720, 10.89842]...|
|  1580|[[5731, 8.466453]...|
|  1591|[[1610, 12.835257...|
|  1645|[[1610, 11.968531...|
|  1829|[[1610, 17.576496...|
|  1959|[[1610, 8.353473]...|
|  2122|[[1610, 12.652732...|
|  2142|[[1610, 12.48068]...|
|  2366|[[1610, 11.904813...|
|  2659|[[5607, 11.699315...|
|  2866|[[1610, 7.752719]...|
|  3175|[[3347, 2.3429515...|
|  3749|[[1610, 3.641833]...|
+------+--------------------+
only showing top 20 rows
  • 召回到redis
import redis
host = "192.168.19.8"
port = 6379    
# 召回到redis
def recall_cate_by_cf(partition):
    # 建立redis 连接池
    pool = redis.ConnectionPool(host=host, port=port)
    # 建立redis客户端
    client = redis.Redis(connection_pool=pool)
    for row in partition:
        client.hset("recall_cate", row.userId, [i.cateId for i in row.recommendations])
# 对每个分片的数据进行处理 #mapPartition Transformation   map
# foreachPartition Action操作             foreachRDD
result.foreachPartition(recall_cate_by_cf)

# 注意:这里这是召回的是用户最感兴趣的n个类别
# 总的条目数,查看redis中总的条目数是否一致
result.count()

显示结果:

1136340

2.3 根据用户对品牌偏好打分训练ALS模型

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("userId", IntegerType()),
    StructField("brandId", IntegerType()),
    StructField("pv", IntegerType()),
    StructField("fav", IntegerType()),
    StructField("cart", IntegerType()),
    StructField("buy", IntegerType())
])
# 从hdfs加载预处理好的品牌的统计数据
brand_count_df = spark.read.csv("hdfs://localhost:8020/preprocessing_dataset/brand_count.csv", header=True, schema=schema)
# brand_count_df.show()
def process_row(r):
    # 处理每一行数据:r表示row对象
    
    # 偏好评分规则:
	#     m: 用户对应的行为次数
    #     该偏好权重比例,次数上限仅供参考,具体数值应根据产品业务场景权衡
	#     pv: if m<=20: score=0.2*m; else score=4
	#     fav: if m<=20: score=0.4*m; else score=8
	#     cart: if m<=20: score=0.6*m; else score=12
	#     buy: if m<=20: score=1*m; else score=20
    
    # 注意这里要全部设为浮点数,spark运算时对类型比较敏感,要保持数据类型都一致
	pv_count = r.pv if r.pv else 0.0
	fav_count = r.fav if r.fav else 0.0
	cart_count = r.cart if r.cart else 0.0
	buy_count = r.buy if r.buy else 0.0

	pv_score = 0.2*pv_count if pv_count<=20 else 4.0
	fav_score = 0.4*fav_count if fav_count<=20 else 8.0
	cart_score = 0.6*cart_count if cart_count<=20 else 12.0
	buy_score = 1.0*buy_count if buy_count<=20 else 20.0

	rating = pv_score + fav_score + cart_score + buy_score
	# 返回用户ID、品牌ID、用户对品牌的偏好打分
	return r.userId, r.brandId, rating
# 用户对品牌的打分数据
brand_rating_df = brand_count_df.rdd.map(process_row).toDF(["userId", "brandId", "rating"])
# brand_rating_df.show()
  • 基于Spark的ALS隐因子模型进行CF评分预测

    • ALS的意思是交替最小二乘法(Alternating Least Squares),是Spark中进行基于模型的协同过滤(model-based CF)的推荐系统算法,也是目前Spark内唯一一个推荐算法。

      同SVD,它也是一种矩阵分解技术,但理论上,ALS在海量数据的处理上要优于SVD。

      更多了解:pyspark.ml.recommendation.ALS

      注意:由于数据量巨大,因此这里不考虑基于内存的CF算法

      参考:为什么Spark中只有ALS

  • 使用pyspark中的ALS矩阵分解方法实现CF评分预测

# 使用pyspark中的ALS矩阵分解方法实现CF评分预测
# 文档地址:https://spark.apache.org/docs/latest/api/python/pyspark.ml.html?highlight=vectors#module-pyspark.ml.recommendation
from pyspark.ml.recommendation import ALS

als = ALS(userCol='userId', itemCol='brandId', ratingCol='rating', checkpointInterval=2)
# 利用打分数据,训练ALS模型
# 此处训练时间较长
model = als.fit(brand_rating_df)
# model.recommendForAllUsers(N) 给用户推荐TOP-N个物品
model.recommendForAllUsers(3).show()
# 将模型进行存储
model.save("hdfs://localhost:9000/models/userBrandRatingModel.obj")
# 测试存储的模型
from pyspark.ml.recommendation import ALSModel
# 从hdfs加载模型
my_model = ALSModel.load("hdfs://localhost:9000/models/userBrandRatingModel.obj")
my_model
# model.recommendForAllUsers(N) 给用户推荐TOP-N个物品
my_model.recommendForAllUsers(3).first()

三 CTR预估数据准备

3.1 分析并预处理raw_sample数据集

# 从HDFS中加载样本数据信息
df = spark.read.csv("hdfs://localhost:9000/datasets/raw_sample.csv", header=True)
df.show()    # 展示数据,默认前20条
df.printSchema()

显示结果:

+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)
  • 分析数据集字段的类型和格式
    • 查看是否有空值
    • 查看每列数据的类型
    • 查看每列数据的类别情况
print("样本数据集总条目数:", df.count())
# 约2600w
print("用户user总数:", df.groupBy("user").count().count())
# 约 114w,略多余日志数据中用户数
print("广告id adgroup_id总数:", df.groupBy("adgroup_id").count().count())
# 约85w
print("广告展示位pid情况:", df.groupBy("pid").count().collect())
# 只有两种广告展示位,占比约为六比四
print("广告点击数据情况clk:", df.groupBy("clk").count().collect())
# 点和不点比率约: 1:20

显示结果:

样本数据集总条目数: 26557961
用户user总数: 1141729
广告id adgroup_id总数: 846811
广告展示位pid情况: [Row(pid='430548_1007', count=16472898), Row(pid='430539_1007', count=10085063)]
广告点击数据情况clk: [Row(clk='0', count=25191905), Row(clk='1', count=1366056)]
  • 使用dataframe.withColumn更改df列数据结构;使用dataframe.withColumnRenamed更改列名称
# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType

# 打印df结构信息
df.printSchema()   
# 更改df表结构:更改列类型和列名称
raw_sample_df = df.\
    withColumn("user", df.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", df.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", df.pid.cast(StringType())).\
    withColumn("nonclk", df.nonclk.cast(IntegerType())).\
    withColumn("clk", df.clk.cast(IntegerType()))
raw_sample_df.printSchema()
raw_sample_df.show()

显示结果:

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows
  • 特征选取(Feature Selection)

    • 特征选择就是选择那些靠谱的Feature,去掉冗余的Feature,对于搜索广告,Query关键词和广告的匹配程度很重要;但对于展示广告,广告本身的历史表现,往往是最重要的Feature。

      根据经验,该数据集中,只有广告展示位pid对比较重要,且数据不同数据之间的占比约为6:4,因此pid可以作为一个关键特征

      nonclk和clk在这里是作为目标值,不做为特征

  • 热独编码 OneHotEncode

    • 热独编码是一种经典编码,是使用N位状态寄存器(如0和1)来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候,其中只有一位有效。

      假设有三组特征,分别表示年龄,城市,设备;

      [“男”, “女”][0,1]

      [“北京”, “上海”, “广州”][0,1,2]

      [“苹果”, “小米”, “华为”, “微软”][0,1,2,3]

      传统变化: 对每一组特征,使用枚举类型,从0开始;

      ["男“,”上海“,”小米“]=[ 0,1,1]

      ["女“,”北京“,”苹果“] =[1,0,0]

      传统变化后的数据不是连续的,而是随机分配的,不容易应用在分类器中

      而经过热独编码,数据会变成稀疏的,方便分类器处理:

      ["男“,”上海“,”小米“]=[ 1,0,0,1,0,0,1,0,0]

      ["女“,”北京“,”苹果“] =[0,1,1,0,0,1,0,0,0]

      这样做保留了特征的多样性,但是也要注意如果数据过于稀疏(样本较少、维度过高),其效果反而会变差

  • Spark中使用热独编码

    • 注意:热编码只能对字符串类型的列数据进行处理

      StringIndexer:对指定字符串列数据进行特征处理,如将性别数据“男”、“女”转化为0和1

      OneHotEncoder:对特征列数据,进行热编码,通常需结合StringIndexer一起使用

      Pipeline:让数据按顺序依次被处理,将前一次的处理结果作为下一次的输入

  • 特征处理

'''特征处理'''
'''
pid 资源位。该特征属于分类特征,只有两类取值,因此考虑进行热编码处理即可,分为是否在资源位1、是否在资源位2 两个特征
'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# StringIndexer对指定字符串列进行特征处理
stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')

# 对处理出来的特征处理列进行,热独编码
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
# 利用管道对每一个数据进行热独编码处理
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_model = pipeline.fit(raw_sample_df)
new_df = pipeline_model.transform(raw_sample_df)
new_df.show()

显示结果:

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

from pyspark.ml.linalg import SparseVector
# 参数:维度、索引列表、值列表
print(SparseVector(4, [1, 3], [3.0, 4.0]))
print(SparseVector(4, [1, 3], [3.0, 4.0]).toArray())
print("*********")
print(new_df.select("pid_value").first())
print(new_df.select("pid_value").first().pid_value.toArray())

显示结果:

(4,[1,3],[3.0,4.0])
[0. 3. 0. 4.]
*********
Row(pid_value=SparseVector(2, {0: 1.0}))
[1. 0.]

  • 查看最大时间
new_df.sort("timestamp", ascending=False).show()
+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|177002|1494691186|   593001|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|243671|1494691186|   600195|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   494312|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   431082|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
| 17054|1494691184|   742741|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
| 17054|1494691184|   756665|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   687854|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|839493|1494691183|   561681|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|704223|1494691183|   624504|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|839493|1494691183|   582235|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|704223|1494691183|   675674|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|628998|1494691180|   618965|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|674444|1494691179|   427579|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|627200|1494691179|   782038|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|627200|1494691179|   420769|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|674444|1494691179|   588664|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|738335|1494691179|   451004|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|627200|1494691179|   817569|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|322244|1494691179|   820018|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|322244|1494691179|   735220|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows
# 本样本数据集共计8天数据
# 前七天为训练数据、最后一天为测试数据

from datetime import datetime
datetime.fromtimestamp(1494691186)
print("该时间之前的数据为训练样本,该时间以后的数据为测试样本:", datetime.fromtimestamp(1494691186-24*60*60))

显示结果:

该时间之前的数据为训练样本,该时间以后的数据为测试样本: 2017-05-12 23:59:46

  • 训练样本
# 训练样本:
train_sample = raw_sample_df.filter(raw_sample_df.timestamp<=(1494691186-24*60*60))
print("训练样本个数:")
print(train_sample.count())
# 测试样本
test_sample = raw_sample_df.filter(raw_sample_df.timestamp>(1494691186-24*60*60))
print("测试样本个数:")
print(test_sample.count())

# 注意:还需要加入广告基本特征和用户基本特征才能做程一份完整的样本数据集

显示结果:

训练样本个数:
23249291
测试样本个数:
3308670

3.2 分析并预处理ad_feature数据集

# 从HDFS中加载广告基本信息数据,返回spark dafaframe对象
df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)
df.show()    # 展示数据,默认前20条

显示结果:

+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|  NULL| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|  NULL|428.0|
|     23236|   5953|     395195|      13|  NULL|368.0|
|    300556|   5953|     395195|      13|  NULL|639.0|
|     92560|   5953|     395195|      13|  NULL|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|  NULL|249.0|
|    546930|   4284|      28145|      14|  NULL|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|  NULL|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows
# 注意:由于本数据集中存在NULL字样的数据,无法直接设置schema,只能先将NULL类型的数据处理掉,然后进行类型转换

from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串,替换掉
df = df.replace("NULL", "-1")

# 打印df结构信息
df.printSchema()   
# 更改df表结构:更改列类型和列名称
ad_feature_df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()

显示结果:

root
 |-- adgroup_id: string (nullable = true)
 |-- cate_id: string (nullable = true)
 |-- campaign_id: string (nullable = true)
 |-- customer: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: string (nullable = true)

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
  • 查看各项数据的特征
print("总广告条数:",df.count())   # 数据条数
_1 = ad_feature_df.groupBy("cateId").count().count()
print("cateId数值个数:", _1)
_2 = ad_feature_df.groupBy("campaignId").count().count()
print("campaignId数值个数:", _2)
_3 = ad_feature_df.groupBy("customerId").count().count()
print("customerId数值个数:", _3)
_4 = ad_feature_df.groupBy("brandId").count().count()
print("brandId数值个数:", _4)
ad_feature_df.sort("price").show()
ad_feature_df.sort("price", ascending=False).show()
print("价格高于1w的条目个数:", ad_feature_df.select("price").filter("price>10000").count())
print("价格低于1的条目个数", ad_feature_df.select("price").filter("price<1").count())

显示结果:

总广告条数: 846811
cateId数值个数: 6769
campaignId数值个数: 423436
customerId数值个数: 255875
brandId数值个数: 99815
+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|   485749|  9970|    352666|    140520|     -1| 0.01|
|    88975|  9996|    198424|    182415|     -1| 0.01|
|   109704| 10539|     59774|     90351| 202710| 0.01|
|    49911|  7032|    129079|    172334|     -1| 0.01|
|   339334|  9994|    310408|    211292| 383023| 0.01|
|     6636|  6703|    392038|     46239| 406713| 0.01|
|    92241|  6130|     72781|    149714|     -1| 0.01|
|    20397| 10539|    410958|     65726|  79971| 0.01|
|   345870|  9995|    179595|    191036|  79971| 0.01|
|    77797|  9086|    218276|     31183|     -1| 0.01|
|    14435|  1136|    135610|     17788|     -1| 0.01|
|    42055|  9994|     43866|    113068| 123242| 0.01|
|    41925|  7032|     85373|    114532|     -1| 0.01|
|    67558|  9995|     90141|     83948|     -1| 0.01|
|   149570|  7043|    126746|    176076|     -1| 0.01|
|   518883|  7185|    403318|     58013|     -1| 0.01|
|     2246|  9996|    413653|     60214| 182966| 0.01|
|   290675|  4824|    315371|    240984|     -1| 0.01|
|   552638| 10305|    403318|     58013|     -1| 0.01|
|    89831| 10539|     90141|     83948| 211816| 0.01|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows

+---------+------+----------+----------+-------+-----------+
|adgroupId|cateId|campaignId|customerId|brandId|      price|
+---------+------+----------+----------+-------+-----------+
|   658722|  1093|    218101|    207754|     -1|      1.0E8|
|   468220|  1093|    270719|    207754|     -1|      1.0E8|
|   179746|  1093|    270027|    102509| 405447|      1.0E8|
|   443295|  1093|     44251|    102509| 300681|      1.0E8|
|    31899|   685|    218918|     31239| 278301|      1.0E8|
|   243384|   685|    218918|     31239| 278301|      1.0E8|
|   554311|  1093|    266086|    207754|     -1|      1.0E8|
|   513942|   745|      8401|     86243|     -1|8.8888888E7|
|   201060|   745|      8401|     86243|     -1|5.5555556E7|
|   289563|   685|     37665|    120847| 278301|      1.5E7|
|    35156|   527|    417722|     72273| 278301|      1.0E7|
|    33756|   527|    416333|     70894|     -1|  9900000.0|
|   335495|   739|    170121|    148946| 326126|  9600000.0|
|   218306|   206|    162394|      4339| 221720|  8888888.0|
|   213567|  7213|    239302|    205612| 406125|  5888888.0|
|   375920|   527|    217512|    148946| 326126|  4760000.0|
|   262215|   527|    132721|     11947| 417898|  3980000.0|
|   154623|   739|    170121|    148946| 326126|  3900000.0|
|   152414|   739|    170121|    148946| 326126|  3900000.0|
|   448651|   527|    422260|     41289| 209959|  3800000.0|
+---------+------+----------+----------+-------+-----------+
only showing top 20 rows

价格高于1w的条目个数: 6527
价格低于1的条目个数 5762

  • 特征选择

    • cateId:脱敏过的商品类目ID;
    • campaignId:脱敏过的广告计划ID;
    • customerId:脱敏过的广告主ID;
    • brandId:脱敏过的品牌ID;

    以上四个特征均属于分类特征,但由于分类值个数均过于庞大,如果去做热独编码处理,会导致数据过于稀疏 且当前我们缺少对这些特征更加具体的信息,(如商品类目具体信息、品牌具体信息等),从而无法对这些特征的数据做聚类、降维处理 因此这里不选取它们作为特征

    而只选取price作为特征数据,因为价格本身是一个统计类型连续数值型数据,且能很好的体现广告的价值属性特征,通常也不需要做其他处理(离散化、归一化、标准化等),所以这里直接将当做特征数据来使用

3.3 分析并预处理user_profile数据集

# 从HDFS加载用户基本信息数据
df = spark.read.csv("hdfs://localhost:8020/csv/user_profile.csv", header=True)
# 发现pvalue_level和new_user_class_level存在空值:(注意此处的null表示空值,而如果是NULL,则往往表示是一个字符串)
# 因此直接利用schema就可以加载进该数据,无需替换null值
df.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
# 注意:这里的null会直接被pyspark识别为None数据,也就是na数据,所以这里可以直接利用schema导入数据

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建表结构schema对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema从hdfs加载
user_profile_df = spark.read.csv("hdfs://localhost:8020/csv/user_profile.csv", header=True, schema=schema)
user_profile_df.printSchema()
user_profile_df.show()

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: integer (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
  • 显示特征情况
print("分类特征值个数情况: ")
print("cms_segid: ", user_profile_df.groupBy("cms_segid").count().count())
print("cms_group_id: ", user_profile_df.groupBy("cms_group_id").count().count())
print("final_gender_code: ", user_profile_df.groupBy("final_gender_code").count().count())
print("age_level: ", user_profile_df.groupBy("age_level").count().count())
print("shopping_level: ", user_profile_df.groupBy("shopping_level").count().count())
print("occupation: ", user_profile_df.groupBy("occupation").count().count())

print("含缺失值的特征情况: ")
user_profile_df.groupBy("pvalue_level").count().show()
user_profile_df.groupBy("new_user_class_level").count().show()

t_count = user_profile_df.count()
pl_na_count = t_count - user_profile_df.dropna(subset=["pvalue_level"]).count()
print("pvalue_level的空值情况:", pl_na_count, "空值占比:%0.2f%%"%(pl_na_count/t_count*100))
nul_na_count = t_count - user_profile_df.dropna(subset=["new_user_class_level"]).count()
print("new_user_class_level的空值情况:", nul_na_count, "空值占比:%0.2f%%"%(nul_na_count/t_count*100))

显示内容:

分类特征值个数情况: 
cms_segid:  97
cms_group_id:  13
final_gender_code:  2
age_level:  7
shopping_level:  3
occupation:  2
含缺失值的特征情况: 
+------------+------+
|pvalue_level| count|
+------------+------+
|        null|575917|
|           1|154436|
|           3| 37759|
|           2|293656|
+------------+------+

+--------------------+------+
|new_user_class_level| count|
+--------------------+------+
|                null|344920|
|                   1| 80548|
|                   3|173047|
|                   4|138833|
|                   2|324420|
+--------------------+------+

pvalue_level的空值情况: 575917 空值占比:54.24%
new_user_class_level的空值情况: 344920 空值占比:32.49%
  • 缺失值处理

    • 注意,一般情况下:

      • 缺失率低于10%:可直接进行相应的填充,如默认值、均值、算法拟合等等;
      • 高于10%:往往会考虑舍弃该特征
      • 特征处理,如1维转多维

      但根据我们的经验,我们的广告推荐其实和用户的消费水平、用户所在城市等级都有比较大的关联,因此在这里pvalue_level、new_user_class_level都是比较重要的特征,我们不考虑舍弃

  • 缺失值处理方案:

    • 填充方案:结合用户的其他特征值,利用随机森林算法进行预测;但产生了大量人为构建的数据,一定程度上增加了数据的噪音
    • 把变量映射到高维空间:如pvalue_level的1维数据,转换成是否1、是否2、是否3、是否缺失的4维数据;这样保证了所有原始数据不变,同时能提高精确度,但这样会导致数据变得比较稀疏,如果样本量很小,反而会导致样本效果较差,因此也不能滥用
  • 填充方案

    • 利用随机森林对pvalue_level的缺失值进行预测
from pyspark.mllib.regression import LabeledPoint

# 剔除掉缺失值数据,将余下的数据作为训练数据
# user_profile_df.dropna(subset=["pvalue_level"]): 将pvalue_level中的空值所在行数据剔除后的数据,作为训练样本
train_data = user_profile_df.dropna(subset=["pvalue_level"]).rdd.map(
    lambda r:LabeledPoint(r.pvalue_level-1, [r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation])
)

# 注意随机森林输入数据时,由于label的分类数是从0开始的,但pvalue_level的目前只分别是1,2,3,所以需要对应分别-1来作为目标值
# 自然那么最终得出预测值后,需要对应+1才能还原回来

# 我们使用cms_segid, cms_group_id, final_gender_code, age_level, shopping_level, occupation作为特征值,pvalue_level作为目标值
  • Labeled point

A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. For binary classification, a label should be either 0 (negative) or 1 (positive). For multiclass classification, labels should be class indices starting from zero: 0, 1, 2, ….
标记点是与标签/响应相关联的密集或稀疏的局部矢量。在MLlib中,标记点用于监督学习算法。我们使用double来存储标签,因此我们可以在回归和分类中使用标记点。对于二进制分类,标签应为0(负)或1(正)。对于多类分类,标签应该是从零开始的类索引:0, 1, 2, …。

Python
A labeled point is represented by LabeledPoint.
标记点表示为 LabeledPoint。
Refer to the LabeledPoint Python docs for more details on the API.
有关API的更多详细信息,请参阅LabeledPointPython文档。

from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint

# Create a labeled point with a positive label and a dense feature vector.
pos = LabeledPoint(1.0, [1.0, 0.0, 3.0])

# Create a labeled point with a negative label and a sparse feature vector.
neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0]))
from pyspark.mllib.tree import RandomForest
# 训练分类模型
# 参数1 训练的数据
#参数2 目标值的分类个数 0,1,2
#参数3 特征中是否包含分类的特征 {2:2,3:7} {2:2} 表示 在特征中 第二个特征是分类的: 有两个分类
#参数4 随机森林中 树的棵数
model = RandomForest.trainClassifier(train_data, 3, {
    
    }, 5)
# 预测单个数据
# 注意用法:https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html?highlight=tree%20random#pyspark.mllib.tree.RandomForestModel.predict
model.predict([0.0, 4.0 ,2.0 , 4.0, 1.0, 0.0])

显示结果:

1.0
  • 筛选出缺失值条目
pl_na_df = user_profile_df.na.fill(-1).where("pvalue_level=-1")
pl_na_df.show(10)

def row(r):
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

# 转换为普通的rdd类型
rdd = pl_na_df.rdd.map(row)
# 预测全部的pvalue_level值:
predicts = model.predict(rdd)
# 查看前20条
print(predicts.take(20))
print("预测值总数", predicts.count())

# 这里注意predict参数,如果是预测多个,那么参数必须是直接有列表构成的rdd参数,而不能是dataframe.rdd类型
# 因此这里经过map函数处理,将每一行数据转换为普通的列表数据

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11602|        0|           5|                2|        5|          -1|             3|         0|                   2|
| 11727|        0|           3|                2|        3|          -1|             3|         0|                   1|
| 12195|        0|          10|                1|        4|          -1|             3|         0|                   2|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 10 rows

[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]
预测值总数 575917
  • 转换为pandas dataframe
# 这里数据量比较小,直接转换为pandas dataframe来处理,因为方便,但注意如果数据量较大不推荐,因为这样会把全部数据加载到内存中
temp = predicts.map(lambda x:int(x)).collect()
pdf = pl_na_df.toPandas()
import numpy as np
 # 在pandas df的基础上直接替换掉列数据
pdf["pvalue_level"] = np.array(temp) + 1  # 注意+1 还原预测值
pdf
  • 与非缺失数据进行拼接,完成pvalue_level的缺失值预测
new_user_profile_df = user_profile_df.dropna(subset=["pvalue_level"]).unionAll(spark.createDataFrame(pdf, schema=schema))
new_user_profile_df.show()

# 注意:unionAll的使用,两个df的表结构必须完全一样

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                null|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
| 11739|       20|           3|                2|        3|           2|             3|         0|                   4|
| 12549|       33|           4|                2|        4|           2|             3|         0|                   2|
| 15155|       36|           5|                2|        5|           2|             1|         0|                null|
| 15347|       20|           3|                2|        3|           2|             3|         0|                   3|
| 15455|        8|           2|                2|        2|           2|             3|         0|                   3|
| 15783|        0|           4|                2|        4|           2|             3|         0|                null|
| 16749|        5|           2|                2|        2|           1|             3|         1|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
  • 利用随机森林对new_user_class_level的缺失值进行预测
from pyspark.mllib.regression import LabeledPoint

# 选出new_user_class_level全部的
train_data2 = user_profile_df.dropna(subset=["new_user_class_level"]).rdd.map(
    lambda r:LabeledPoint(r.new_user_class_level - 1, [r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation])
)
from pyspark.mllib.tree import RandomForest
model2 = RandomForest.trainClassifier(train_data2, 4, {
    
    }, 5)
model2.predict([0.0, 4.0 ,2.0 , 4.0, 1.0, 0.0])
# 预测值实际应该为2

显示结果:

1.0
nul_na_df = user_profile_df.na.fill(-1).where("new_user_class_level=-1")
nul_na_df.show(10)

def row(r):
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

rdd2 = nul_na_df.rdd.map(row)
predicts2 = model.predict(rdd2)
predicts2.take(20)
  • 显示结果:
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 12620|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 14437|        0|           5|                2|        5|          -1|             3|         0|                  -1|
| 14574|        0|           1|                2|        1|          -1|             2|         0|                  -1|
| 14985|        0|          11|                1|        5|          -1|             2|         0|                  -1|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 10 rows

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0]

  • 总结:可以发现由于这两个字段的缺失过多,所以预测出来的值已经大大失真,但如果缺失率在10%以下,这种方法是比较有效的一种
user_profile_df = user_profile_df.na.fill(-1)
user_profile_df.show()
# new_df = new_df.withColumn("pvalue_level", new_df.pvalue_level.cast(StringType()))\
#     .withColumn("new_user_class_level", new_df.new_user_class_level.cast(StringType()))

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
  • 低维转高维方式
    • 我们接下来采用将变量映射到高维空间的方法来处理数据,即将缺失项也当做一个单独的特征来对待,保证数据的原始性
      由于该思想正好和热独编码实现方法一样,因此这里直接使用热独编码方式处理数据
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用热独编码转换pvalue_level的一维数据为多维,其中缺失值单独作为一个特征值

# 需要先将缺失值全部替换为数值,与原有特征一起处理
from pyspark.sql.types import StringType
user_profile_df = user_profile_df.na.fill(-1)
user_profile_df.show()

# 热独编码时,必须先将待处理字段转为字符串类型才可处理
user_profile_df = user_profile_df.withColumn("pvalue_level", user_profile_df.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", user_profile_df.new_user_class_level.cast(StringType()))
user_profile_df.printSchema()

# 对pvalue_level进行热独编码,求值
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(user_profile_df)
user_profile_df2 = pipeline_fit.transform(user_profile_df)
# pl_onehot_value列的值为稀疏向量,存储热独编码的结果
user_profile_df2.printSchema()
user_profile_df2.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = false)
 |-- pl_onehot_value: vector (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
only showing top 20 rows

  • 使用热编码转换new_user_class_level的一维数据为多维
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(user_profile_df2)
user_profile_df3 = pipeline_fit.transform(user_profile_df2)
user_profile_df3.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows
  • 用户特征合并
from pyspark.ml.feature import VectorAssembler
feature_df = VectorAssembler().setInputCols(["age_level", "pl_onehot_value", "nucl_onehot_value"]).setOutputCol("features").transform(user_profile_df3)
feature_df.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|            features|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|(10,[0,1,7],[5.0,...|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,3,6],[2.0,...|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[2.0,...|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[6.0,...|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[5.0,...|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,1,6],[3.0,...|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,3,8],[1.0,...|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|(10,[0,2,9],[5.0,...|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[2.0,...|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,1,8],[5.0,...|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[2.0,...|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[4.0,...|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[4.0,...|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[4.0,...|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,1,8],[5.0,...|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|(10,[0,3,7],[2.0,...|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,3,8],[4.0,...|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
only showing top 20 rows
feature_df.select("features").show()

显示结果:

+--------------------+
|            features|
+--------------------+
|(10,[0,1,7],[5.0,...|
|(10,[0,3,6],[2.0,...|
|(10,[0,2,5],[2.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,2,6],[6.0,...|
|(10,[0,2,6],[5.0,...|
|(10,[0,1,6],[3.0,...|
|(10,[0,3,8],[1.0,...|
|(10,[0,2,9],[5.0,...|
|(10,[0,2,6],[2.0,...|
|(10,[0,1,8],[5.0,...|
|(10,[0,2,6],[2.0,...|
|(10,[0,2,6],[4.0,...|
|(10,[0,2,5],[4.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,2,5],[4.0,...|
|(10,[0,1,8],[5.0,...|
|(10,[0,3,7],[2.0,...|
|(10,[0,3,8],[4.0,...|
+--------------------+
only showing top 20 rows
  • 特征选取

除了前面处理的pvalue_level和new_user_class_level需要作为特征以外,(能体现出用户的购买力特征),还有:

前面分析的以下几个分类特征值个数情况:

- cms_segid:  97
- cms_group_id:  13
- final_gender_code:  2
- age_level:  7
- shopping_level:  3
- occupation:  2
-pvalue_level
-new_user_class_level
-price

根据经验,以上几个分类特征都一定程度能体现用户在购物方面的特征,且类别都较少,都可以用来作为用户特征

四 LR实现CTR预估

4.1 Spark逻辑回归(LR)模型使用介绍

from pyspark.ml.feature import VectorAssembler
import pandas as pd

# 样本数据集
sample_dataset = [
    (0, "male", 37, 10, "no", 3, 18, 7, 4),
    (0, "female", 27, 4, "no", 4, 14, 6, 4),
    (0, "female", 32, 15, "yes", 1, 12, 1, 4),
    (0, "male", 57, 15, "yes", 5, 18, 6, 5),
    (0, "male", 22, 0.75, "no", 2, 17, 6, 3),
    (0, "female", 32, 1.5, "no", 2, 17, 5, 5),
    (0, "female", 22, 0.75, "no", 2, 12, 1, 3),
    (0, "male", 57, 15, "yes", 2, 14, 4, 4),
    (0, "female", 32, 15, "yes", 4, 16, 1, 2),
    (0, "male", 22, 1.5, "no", 4, 14, 4, 5),
    (0, "male", 37, 15, "yes", 2, 20, 7, 2),
    (0, "male", 27, 4, "yes", 4, 18, 6, 4),
    (0, "male", 47, 15, "yes", 5, 17, 6, 4),
    (0, "female", 22, 1.5, "no", 2, 17, 5, 4),
    (0, "female", 27, 4, "no", 4, 14, 5, 4),
    (0, "female", 37, 15, "yes", 1, 17, 5, 5),
    (0, "female", 37, 15, "yes", 2, 18, 4, 3),
    (0, "female", 22, 0.75, "no", 3, 16, 5, 4),
    (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
    (0, "female", 27, 10, "yes", 2, 14, 1, 5),
    (1, "female", 32, 15, "yes", 3, 14, 3, 2),
    (1, "female", 27, 7, "yes", 4, 16, 1, 2),
    (1, "male", 42, 15, "yes", 3, 18, 6, 2),
    (1, "female", 42, 15, "yes", 2, 14, 3, 2),
    (1, "male", 27, 7, "yes", 2, 17, 5, 4),
    (1, "male", 32, 10, "yes", 4, 14, 4, 3),
    (1, "male", 47, 15, "yes", 3, 16, 4, 2),
    (0, "male", 37, 4, "yes", 2, 20, 6, 4)
]

columns = ["affairs", "gender", "age", "label", "children", "religiousness", "education", "occupation", "rating"]

# pandas构建dataframe,方便
pdf = pd.DataFrame(sample_dataset, columns=columns)

# 转换成spark的dataframe
df = spark.createDataFrame(pdf)

# 特征选取:affairs为目标值,其余为特征值
df2 = df.select("affairs","age", "religiousness", "education", "occupation", "rating")

# 用于计算特征向量的字段
colArray2 = ["age", "religiousness", "education", "occupation", "rating"]

# 计算出特征向量
df3 = VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(df2)
print("数据集:")
df3.show()

#  随机切分为训练集和测试集
trainDF, testDF = df3.randomSplit([0.8,0.2])
print("训练集:")
trainDF.show(10)
print("测试集:")
testDF.show(10)

显示结果:

数据集:
+-------+---+-------------+---------+----------+------+--------------------+
|affairs|age|religiousness|education|occupation|rating|            features|
+-------+---+-------------+---------+----------+------+--------------------+
|      0| 37|            3|       18|         7|     4|[37.0,3.0,18.0,7....|
|      0| 27|            4|       14|         6|     4|[27.0,4.0,14.0,6....|
|      0| 32|            1|       12|         1|     4|[32.0,1.0,12.0,1....|
|      0| 57|            5|       18|         6|     5|[57.0,5.0,18.0,6....|
|      0| 22|            2|       17|         6|     3|[22.0,2.0,17.0,6....|
|      0| 32|            2|       17|         5|     5|[32.0,2.0,17.0,5....|
|      0| 22|            2|       12|         1|     3|[22.0,2.0,12.0,1....|
|      0| 57|            2|       14|         4|     4|[57.0,2.0,14.0,4....|
|      0| 32|            4|       16|         1|     2|[32.0,4.0,16.0,1....|
|      0| 22|            4|       14|         4|     5|[22.0,4.0,14.0,4....|
|      0| 37|            2|       20|         7|     2|[37.0,2.0,20.0,7....|
|      0| 27|            4|       18|         6|     4|[27.0,4.0,18.0,6....|
|      0| 47|            5|       17|         6|     4|[47.0,5.0,17.0,6....|
|      0| 22|            2|       17|         5|     4|[22.0,2.0,17.0,5....|
|      0| 27|            4|       14|         5|     4|[27.0,4.0,14.0,5....|
|      0| 37|            1|       17|         5|     5|[37.0,1.0,17.0,5....|
|      0| 37|            2|       18|         4|     3|[37.0,2.0,18.0,4....|
|      0| 22|            3|       16|         5|     4|[22.0,3.0,16.0,5....|
|      0| 22|            2|       16|         5|     5|[22.0,2.0,16.0,5....|
|      0| 27|            2|       14|         1|     5|[27.0,2.0,14.0,1....|
+-------+---+-------------+---------+----------+------+--------------------+
only showing top 20 rows

训练集:
+-------+---+-------------+---------+----------+------+--------------------+
|affairs|age|religiousness|education|occupation|rating|            features|
+-------+---+-------------+---------+----------+------+--------------------+
|      0| 32|            1|       12|         1|     4|[32.0,1.0,12.0,1....|
|      0| 37|            3|       18|         7|     4|[37.0,3.0,18.0,7....|
|      0| 22|            2|       17|         6|     3|[22.0,2.0,17.0,6....|
|      0| 32|            2|       17|         5|     5|[32.0,2.0,17.0,5....|
|      0| 57|            5|       18|         6|     5|[57.0,5.0,18.0,6....|
|      0| 57|            2|       14|         4|     4|[57.0,2.0,14.0,4....|
|      0| 22|            2|       17|         5|     4|[22.0,2.0,17.0,5....|
|      0| 22|            4|       14|         4|     5|[22.0,4.0,14.0,4....|
|      0| 27|            4|       18|         6|     4|[27.0,4.0,18.0,6....|
|      0| 37|            2|       20|         7|     2|[37.0,2.0,20.0,7....|
+-------+---+-------------+---------+----------+------+--------------------+
only showing top 10 rows

测试集:
+-------+---+-------------+---------+----------+------+--------------------+
|affairs|age|religiousness|education|occupation|rating|            features|
+-------+---+-------------+---------+----------+------+--------------------+
|      0| 27|            4|       14|         6|     4|[27.0,4.0,14.0,6....|
|      0| 22|            2|       12|         1|     3|[22.0,2.0,12.0,1....|
|      0| 32|            4|       16|         1|     2|[32.0,4.0,16.0,1....|
|      0| 27|            4|       14|         5|     4|[27.0,4.0,14.0,5....|
|      0| 22|            3|       16|         5|     4|[22.0,3.0,16.0,5....|
|      1| 27|            4|       16|         1|     2|[27.0,4.0,16.0,1....|
+-------+---+-------------+---------+----------+------+--------------------+
  • 逻辑回归训练模型
from pyspark.ml.classification import LogisticRegression
# 创建逻辑回归训练器
lr = LogisticRegression()
# 训练模型
model = lr.setLabelCol("affairs").setFeaturesCol("features").fit(trainDF)
# 预测数据
model.transform(testDF).show()

显示结果:

+-------+---+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|affairs|age|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|
+-------+---+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|      0| 27|            4|       14|         6|     4|[27.0,4.0,14.0,6....|[0.39067871041193...|[0.59644607432863...|       0.0|
|      0| 22|            2|       12|         1|     3|[22.0,2.0,12.0,1....|[-2.6754687573263...|[0.06443650129497...|       1.0|
|      0| 32|            4|       16|         1|     2|[32.0,4.0,16.0,1....|[-4.5240336812732...|[0.01072883305878...|       1.0|
|      0| 27|            4|       14|         5|     4|[27.0,4.0,14.0,5....|[0.16206512668426...|[0.54042783360658...|       0.0|
|      0| 22|            3|       16|         5|     4|[22.0,3.0,16.0,5....|[1.69102697292197...|[0.84435916906682...|       0.0|
|      1| 27|            4|       16|         1|     2|[27.0,4.0,16.0,1....|[-4.7969907272012...|[0.00818697014985...|       1.0|
+-------+---+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

4.2 基于LR的点击率预测模型训练

  • 本小节主要根据广告点击样本数据集(raw_sample)、广告基本特征数据集(ad_feature)、用户基本信息数据集(user_profile)构建出了一个完整的样本数据集,并按日期划分为了训练集(前七天)和测试集(最后一天),利用逻辑回归进行训练。

    训练模型时,通过对类别特征数据进行处理,一定程度达到提高了模型的效果

'''从HDFS中加载样本数据信息'''
_raw_sample_df1 = spark.read.csv("hdfs://localhost:8020/csv/raw_sample.csv", header=True)
# _raw_sample_df1.show()    # 展示数据,默认前20条
# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType
  
# 更改df表结构:更改列类型和列名称
_raw_sample_df2 = _raw_sample_df1.\
    withColumn("user", _raw_sample_df1.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", _raw_sample_df1.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", _raw_sample_df1.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", _raw_sample_df1.pid.cast(StringType())).\
    withColumn("nonclk", _raw_sample_df1.nonclk.cast(IntegerType())).\
    withColumn("clk", _raw_sample_df1.clk.cast(IntegerType()))
_raw_sample_df2.printSchema()
_raw_sample_df2.show()

# 样本数据pid特征处理
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_raw_sample_df2)
raw_sample_df = pipeline_fit.transform(_raw_sample_df2)
raw_sample_df.show()

'''pid和特征的对应关系
430548_1007:0
430549_1007:1
'''

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

'pid和特征的对应关系\n430548_1007:0\n430549_1007:1\n'
  • 从HDFS中加载广告基本信息数据
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)

# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表结构:更改列类型和列名称
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()

显示结果:

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
  • 从HDFS加载用户基本信息数据
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建表结构schema对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema从hdfs加载
_user_profile_df1 = spark.read.csv("hdfs://localhost:9000/datasets/user_profile.csv", header=True, schema=schema)
# user_profile_df.printSchema()
# user_profile_df.show()

'''对缺失数据进行特征热编码'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用热编码转换pvalue_level的一维数据为多维,增加n-1个虚拟变量,n为pvalue_level的取值范围

# 需要先将缺失值全部替换为数值,便于处理,否则会抛出异常
from pyspark.sql.types import StringType
_user_profile_df2 = _user_profile_df1.na.fill(-1)
# _user_profile_df2.show()

# 热编码时,必须先将待处理字段转为字符串类型才可处理
_user_profile_df3 = _user_profile_df2.withColumn("pvalue_level", _user_profile_df2.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", _user_profile_df2.new_user_class_level.cast(StringType()))
# _user_profile_df3.printSchema()

# 对pvalue_level进行热编码,求值
# 运行过程是先将pvalue_level转换为一列新的特征数据,然后对该特征数据求出的热编码值,存在了新的一列数据中,类型为一个稀疏矩阵
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df3)
_user_profile_df4 = pipeline_fit.transform(_user_profile_df3)
# pl_onehot_value列的值为稀疏矩阵,存储热编码的结果
# _user_profile_df4.printSchema()
# _user_profile_df4.show()

# 使用热编码转换new_user_class_level的一维数据为多维
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df4)
user_profile_df = pipeline_fit.transform(_user_profile_df4)
user_profile_df.show()

显示结果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows

  • 热编码中:"pvalue_level"特征对应关系:
+------------+----------------------+
|pvalue_level|pl_onehot_feature     |
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+
  • “new_user_class_level”的特征对应关系
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature     |
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+
user_profile_df.groupBy("pvalue_level").min("pl_onehot_feature").show()
user_profile_df.groupBy("new_user_class_level").min("nucl_onehot_feature").show()

显示结果:

+------------+----------------------+
|pvalue_level|min(pl_onehot_feature)|
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+

+--------------------+------------------------+
|new_user_class_level|min(nucl_onehot_feature)|
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+

# raw_sample_df和ad_feature_df合并条件
condition = [raw_sample_df.adgroupId==ad_feature_df.adgroupId]
_ = raw_sample_df.join(ad_feature_df, condition, 'outer')

# _和user_profile_df合并条件
condition2 = [_.userId==user_profile_df.userId]
datasets = _.join(user_profile_df, condition2, "outer")
# 查看datasets的结构
datasets.printSchema()
# 查看datasets条目数
print(datasets.count())

显示结果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)
 |-- pid_feature: double (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_feature: double (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)

26557961
  • 训练CTRModel_Normal:直接将对应的特征的特征值组合成对应的特征向量进行训练
# 剔除冗余、不需要的字段
useful_cols = [
    # 
    # 时间字段,划分训练集和测试集
    "timestamp",
    # label目标值字段
    "clk",  
    # 特征值字段
    "pid_value",       # 资源位的特征向量
    "price",    # 广告价格
    "cms_segid",    # 用户微群ID
    "cms_group_id",    # 用户组ID
    "final_gender_code",    # 用户性别特征,[1,2]
    "age_level",    # 年龄等级,1-
    "shopping_level",
    "occupation",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 筛选指定字段数据,构建新的数据集
datasets_1 = datasets.select(*useful_cols)
# 由于前面使用的是outer方式合并的数据,产生了部分空值数据,这里必须先剔除掉
datasets_1 = datasets_1.dropna()
print("剔除空值数据后,还剩:", datasets_1.count())

显示结果:

剔除空值数据后,还剩: 25029435

  • 根据特征字段计算出特征向量,并划分出训练数据集和测试数据集
from pyspark.ml.feature import VectorAssembler
# 根据特征字段计算特征向量
datasets_1 = VectorAssembler().setInputCols(useful_cols[2:]).setOutputCol("features").transform(datasets_1)
# 训练数据集: 约7天的数据
train_datasets_1 = datasets_1.filter(datasets_1.timestamp<=(1494691186-24*60*60))
# 测试数据集:约1天的数据量
test_datasets_1 = datasets_1.filter(datasets_1.timestamp>(1494691186-24*60*60))
# 所有的特征的特征向量已经汇总到在features字段中
train_datasets_1.show(5)
test_datasets_1.show(5)

显示结果:

+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value| price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494261938|  0|(2,[1],[1.0])| 108.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494261938|  0|(2,[1],[1.0])|1880.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494553913|  0|(2,[1],[1.0])|2360.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494553913|  0|(2,[1],[1.0])|2200.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494436784|  0|(2,[1],[1.0])|5649.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value|price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494677292|  0|(2,[1],[1.0])|176.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|698.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|697.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|247.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|109.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
# 设置目标字段、特征值字段并训练
model = lr.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_1)
# 对模型进行存储
model.save("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 载入训练好的模型
from pyspark.ml.classification import LogisticRegressionModel
model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 根据测试数据进行预测
result_1 = model.transform(test_datasets_1)
# 按probability升序排列数据,probability表示预测结果的概率
# 如果预测值是0,其概率是0.9248,那么反之可推出1的可能性就是1-0.9248=0.0752,即点击概率约为7.52%
# 因为前面提到广告的点击率一般都比较低,所以预测值通常都是0,因此通常需要反减得出点击的概率
result_1.select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.86822033939259...|       0.0|
|  0|      1.0E8|[0.88410457194969...|       0.0|
|  0|      1.0E8|[0.89175497837562...|       0.0|
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  0|      1.5E7|[0.93741450446939...|       0.0|
|  0|      1.5E7|[0.93757135079959...|       0.0|
|  0|      1.5E7|[0.93834723093801...|       0.0|
|  0|     1099.0|[0.93972095713786...|       0.0|
|  0|      338.0|[0.93972134993018...|       0.0|
|  0|      311.0|[0.93972136386626...|       0.0|
|  0|      300.0|[0.93972136954393...|       0.0|
|  0|      278.0|[0.93972138089925...|       0.0|
|  0|      188.0|[0.93972142735283...|       0.0|
|  0|      176.0|[0.93972143354663...|       0.0|
|  0|      168.0|[0.93972143767584...|       0.0|
|  0|      158.0|[0.93972144283734...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  0|      125.0|[0.93972145987031...|       0.0|
|  0|      119.0|[0.93972146296721...|       0.0|
|  0|       78.0|[0.93972148412937...|       0.0|
|  0|      59.98|[0.93972149343040...|       0.0|
|  0|       58.0|[0.93972149445238...|       0.0|
|  0|       56.0|[0.93972149548468...|       0.0|
|  0|       38.0|[0.93972150477538...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  0|       33.0|[0.93972150735613...|       0.0|
|  0|       30.0|[0.93972150890458...|       0.0|
|  0|       27.6|[0.93972151014334...|       0.0|
|  0|       18.0|[0.93972151509838...|       0.0|
|  0|       30.0|[0.93980311191464...|       0.0|
|  0|       28.0|[0.93980311294563...|       0.0|
|  0|       25.0|[0.93980311449212...|       0.0|
|  0|      688.0|[0.93999362023323...|       0.0|
|  0|      339.0|[0.93999379960808...|       0.0|
|  0|      335.0|[0.93999380166395...|       0.0|
|  0|      220.0|[0.93999386077017...|       0.0|
|  0|      176.0|[0.93999388338470...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  0|      122.5|[0.93999391088191...|       0.0|
|  0|       99.0|[0.93999392296012...|       0.0|
|  0|       88.0|[0.93999392861375...|       0.0|
|  0|       79.0|[0.93999393323945...|       0.0|
|  0|       75.0|[0.93999393529532...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       59.9|[0.93999394305620...|       0.0|
|  0|      44.98|[0.93999395072458...|       0.0|
|  0|       35.5|[0.93999395559698...|       0.0|
|  0|       33.0|[0.93999395688189...|       0.0|
|  0|       32.8|[0.93999395698469...|       0.0|
|  0|       30.0|[0.93999395842379...|       0.0|
|  0|       28.0|[0.93999395945172...|       0.0|
|  0|       19.9|[0.93999396361485...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       12.0|[0.93999396767518...|       0.0|
|  0|        6.7|[0.93999397039920...|       0.0|
|  0|      568.0|[0.94000369247841...|       0.0|
|  0|      398.0|[0.94000377983931...|       0.0|
|  0|      158.0|[0.94000390317214...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  0|     4120.0|[0.94001968693052...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|      989.0|[0.94002129549211...|       0.0|
|  0|      672.0|[0.94002145834965...|       0.0|
|  0|      660.0|[0.94002146451460...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      563.0|[0.94002151434789...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      500.0|[0.94002154671382...|       0.0|
|  0|      498.0|[0.94002154774131...|       0.0|
|  0|      440.0|[0.94002157753851...|       0.0|
|  0|      430.0|[0.94002158267595...|       0.0|
|  0|      388.0|[0.94002160425322...|       0.0|
|  0|      369.0|[0.94002161401436...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      348.0|[0.94002162480299...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      298.0|[0.94002165049020...|       0.0|
|  0|      297.0|[0.94002165100394...|       0.0|
|  0|      278.0|[0.94002166076508...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  0|      275.0|[0.94002166230631...|       0.0|
|  0|      273.0|[0.94002166333380...|       0.0|
|  0|      258.0|[0.94002167103995...|       0.0|
|  0|      256.0|[0.94002167206744...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows
  • 查看样本中点击的被实际点击的条目的预测情况
result_1.filter(result_1.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  1|       35.0|[0.94002178560473...|       0.0|
|  1|       49.0|[0.94004219516957...|       0.0|
|  1|      915.0|[0.94021082858784...|       0.0|
|  1|      598.0|[0.94021099096349...|       0.0|
|  1|      568.0|[0.94021100633025...|       0.0|
|  1|      398.0|[0.94021109340848...|       0.0|
|  1|      368.0|[0.94021110877521...|       0.0|
|  1|      299.0|[0.94021114411869...|       0.0|
|  1|      278.0|[0.94021115487539...|       0.0|
|  1|      259.0|[0.94021116460765...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      195.0|[0.94021119738998...|       0.0|
|  1|      188.0|[0.94021120097554...|       0.0|
|  1|      178.0|[0.94021120609778...|       0.0|
|  1|      159.0|[0.94021121583003...|       0.0|
|  1|      149.0|[0.94021122095226...|       0.0|
|  1|      138.0|[0.94021122658672...|       0.0|
|  1|       58.0|[0.94021126756458...|       0.0|
|  1|       49.0|[0.94021127217459...|       0.0|
|  1|       35.0|[0.94021127934572...|       0.0|
|  1|       25.0|[0.94021128446795...|       0.0|
|  1|     2890.0|[0.94028789742257...|       0.0|
|  1|      220.0|[0.94028926340218...|       0.0|
|  1|      188.0|[0.94031410659516...|       0.0|
|  1|       68.0|[0.94031416796289...|       0.0|
|  1|       58.0|[0.94031417307687...|       0.0|
|  1|      198.0|[0.94035413548387...|       0.0|
|  1|      208.0|[0.94039204931181...|       0.0|
|  1|     8888.0|[0.94045237642030...|       0.0|
|  1|      519.0|[0.94045664687995...|       0.0|
|  1|      478.0|[0.94045666780037...|       0.0|
|  1|      349.0|[0.94045673362308...|       0.0|
|  1|      348.0|[0.94045673413334...|       0.0|
|  1|      316.0|[0.94045675046144...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      198.0|[0.94045681067129...|       0.0|
|  1|      187.1|[0.94045681623305...|       0.0|
|  1|      176.0|[0.94045682189685...|       0.0|
|  1|      168.0|[0.94045682597887...|       0.0|
|  1|      160.0|[0.94045683006090...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      135.0|[0.94045684281721...|       0.0|
|  1|      129.0|[0.94045684587872...|       0.0|
|  1|      127.0|[0.94045684689923...|       0.0|
|  1|      125.0|[0.94045684791973...|       0.0|
|  1|      124.0|[0.94045684842999...|       0.0|
|  1|      118.0|[0.94045685149150...|       0.0|
|  1|      109.0|[0.94045685608377...|       0.0|
|  1|      108.0|[0.94045685659402...|       0.0|
|  1|       99.0|[0.94045686118630...|       0.0|
|  1|       98.0|[0.94045686169655...|       0.0|
|  1|       79.8|[0.94045687098314...|       0.0|
|  1|       79.0|[0.94045687139134...|       0.0|
|  1|       77.0|[0.94045687241185...|       0.0|
|  1|       72.5|[0.94045687470798...|       0.0|
|  1|       69.0|[0.94045687649386...|       0.0|
|  1|       68.0|[0.94045687700412...|       0.0|
|  1|       60.0|[0.94045688108613...|       0.0|
|  1|      43.98|[0.94045688926037...|       0.0|
|  1|       40.0|[0.94045689129118...|       0.0|
|  1|       39.9|[0.94045689134220...|       0.0|
|  1|       39.6|[0.94045689149528...|       0.0|
|  1|       32.0|[0.94045689537319...|       0.0|
|  1|       31.0|[0.94045689588345...|       0.0|
|  1|      25.98|[0.94045689844491...|       0.0|
|  1|       23.0|[0.94045689996546...|       0.0|
|  1|       19.0|[0.94045690200647...|       0.0|
|  1|       16.9|[0.94045690307800...|       0.0|
|  1|       10.0|[0.94045690659874...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        0.4|[0.94045691149716...|       0.0|
|  1|     3960.0|[0.94055740378069...|       0.0|
|  1|     3088.0|[0.94055784801535...|       0.0|
|  1|     1689.0|[0.94055856072019...|       0.0|
|  1|      998.0|[0.94055891273943...|       0.0|
|  1|      888.0|[0.94055896877705...|       0.0|
|  1|      788.0|[0.94055901972029...|       0.0|
|  1|      737.0|[0.94055904570133...|       0.0|
|  1|      629.0|[0.94055910071996...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      499.0|[0.94055916694603...|       0.0|
|  1|      468.0|[0.94055918273839...|       0.0|
|  1|      459.0|[0.94055918732327...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

  • 训练CTRModel_AllOneHot

    • “pid_value”, 类别型特征,已被转换为多维特征==> 2维
    • “price”, 统计型特征 ===> 1维
    • “cms_segid”, 类别型特征,约97个分类 ===> 1维
    • “cms_group_id”, 类别型特征,约13个分类 ==> 1维
    • “final_gender_code”, 类别型特征,2个分类 ==> 1维
    • “age_level”, 类别型特征,7个分类 ==> 1维
    • “shopping_level”, 类别型特征,3个分类 ==> 1维
    • “occupation”, 类别型特征,2个分类 ==> 1维
    • “pl_onehot_value”, 类别型特征,已被转换为多维特征 ==> 4维
    • “nucl_onehot_value” 类别型特征,已被转换为多维特征 ==> 5维

    类别性特征都可以考虑进行热独编码,将单一变量变为多变量,相当于增加了相关特征的数量

    • “cms_segid”, 类别型特征,约97个分类 ===> 97维 舍弃
    • “cms_group_id”, 类别型特征,约13个分类 ==> 13维
    • “final_gender_code”, 类别型特征,2个分类 ==> 2维
    • “age_level”, 类别型特征,7个分类 ==>7维
    • “shopping_level”, 类别型特征,3个分类 ==> 3维
    • “occupation”, 类别型特征,2个分类 ==> 2维

    但由于cms_segid分类过多,这里考虑舍弃,避免数据过于稀疏

datasets_1.first()

显示结果:

datasets_1.first()
datasets_1.first()
Row(timestamp=1494261938, clk=0, pid_value=SparseVector(2, {
    
    1: 1.0}), price=1880.0, cms_segid=0, cms_group_id=11, final_gender_code=1, age_level=5, shopping_level=3, occupation=0, pl_onehot_value=SparseVector(4, {
    
    0: 1.0}), nucl_onehot_value=SparseVector(5, {
    
    1: 1.0}), features=SparseVector(18, {
    
    1: 1.0, 2: 1880.0, 4: 11.0, 5: 1.0, 6: 5.0, 7: 3.0, 9: 1.0, 14: 1.0}))
# 先将下列五列数据转为字符串类型,以便于进行热独编码
# - "cms_group_id",   类别型特征,约13个分类 ==> 13
# - "final_gender_code", 类别型特征,2个分类 ==> 2
# - "age_level",    类别型特征,7个分类 ==>7
# - "shopping_level",    类别型特征,3个分类 ==> 3
# - "occupation",    类别型特征,2个分类 ==> 2

datasets_2 = datasets.withColumn("cms_group_id", datasets.cms_group_id.cast(StringType()))\
    .withColumn("final_gender_code", datasets.final_gender_code.cast(StringType()))\
    .withColumn("age_level", datasets.age_level.cast(StringType()))\
    .withColumn("shopping_level", datasets.shopping_level.cast(StringType()))\
    .withColumn("occupation", datasets.occupation.cast(StringType()))
useful_cols_2 = [
    # 时间值,划分训练集和测试集
    "timestamp",
    # label目标值
    "clk",  
    # 特征值
    "price",
    "cms_group_id",
    "final_gender_code",
    "age_level",
    "shopping_level",
    "occupation",
    "pid_value", 
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 筛选指定字段数据
datasets_2 = datasets_2.select(*useful_cols_2)
# 由于前面使用的是outer方式合并的数据,产生了部分空值数据,这里必须先剔除掉
datasets_2 = datasets_2.dropna()


from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
# 热编码处理函数封装
def oneHotEncoder(col1, col2, col3, data):
    stringindexer = StringIndexer(inputCol=col1, outputCol=col2)
    encoder = OneHotEncoder(dropLast=False, inputCol=col2, outputCol=col3)
    pipeline = Pipeline(stages=[stringindexer, encoder])
    pipeline_fit = pipeline.fit(data)
    return pipeline_fit.transform(data)

# 对这五个字段进行热独编码
#     "cms_group_id",
#     "final_gender_code",
#     "age_level",
#     "shopping_level",
#     "occupation",
datasets_2 = oneHotEncoder("cms_group_id", "cms_group_id_feature", "cms_group_id_value", datasets_2)
datasets_2 = oneHotEncoder("final_gender_code", "final_gender_code_feature", "final_gender_code_value", datasets_2)
datasets_2 = oneHotEncoder("age_level", "age_level_feature", "age_level_value", datasets_2)
datasets_2 = oneHotEncoder("shopping_level", "shopping_level_feature", "shopping_level_value", datasets_2)
datasets_2 = oneHotEncoder("occupation", "occupation_feature", "occupation_value", datasets_2)
  • "cms_group_id"特征对应关系:
+------------+-------------------------+
|cms_group_id|min(cms_group_id_feature)|
+------------+-------------------------+
|           7|                      9.0|
|          11|                      6.0|
|           3|                      0.0|
|           8|                      8.0|
|           0|                     12.0|
|           5|                      3.0|
|           6|                     10.0|
|           9|                      5.0|
|           1|                      7.0|
|          10|                      4.0|
|           4|                      1.0|
|          12|                     11.0|
|           2|                      2.0|
+------------+-------------------------+

  • "final_gender_code"特征对应关系:
+-----------------+------------------------------+
|final_gender_code|min(final_gender_code_feature)|
+-----------------+------------------------------+
|                1|                           1.0|
|                2|                           0.0|
+-----------------+------------------------------+

  • "age_level"特征对应关系:
+---------+----------------------+
|age_level|min(age_level_feature)|
+---------+----------------------+
|        3|                   0.0|
|        0|                   6.0|
|        5|                   2.0|
|        6|                   5.0|
|        1|                   4.0|
|        4|                   1.0|
|        2|                   3.0|
+---------+----------------------+

  • "shopping_level"特征对应关系:
|shopping_level|min(shopping_level_feature)|
+--------------+---------------------------+
|             3|                        0.0|
|             1|                        2.0|
|             2|                        1.0|
+--------------+---------------------------+

  • "occupation"特征对应关系:
+----------+-----------------------+
|occupation|min(occupation_feature)|
+----------+-----------------------+
|         0|                    0.0|
|         1|                    1.0|
+----------+-----------------------+

datasets_2.groupBy("cms_group_id").min("cms_group_id_feature").show()
datasets_2.groupBy("final_gender_code").min("final_gender_code_feature").show()
datasets_2.groupBy("age_level").min("age_level_feature").show()
datasets_2.groupBy("shopping_level").min("shopping_level_feature").show()
datasets_2.groupBy("occupation").min("occupation_feature").show()

显示结果:

+------------+-------------------------+
|cms_group_id|min(cms_group_id_feature)|
+------------+-------------------------+
|           7|                      9.0|
|          11|                      6.0|
|           3|                      0.0|
|           8|                      8.0|
|           0|                     12.0|
|           5|                      3.0|
|           6|                     10.0|
|           9|                      5.0|
|           1|                      7.0|
|          10|                      4.0|
|           4|                      1.0|
|          12|                     11.0|
|           2|                      2.0|
+------------+-------------------------+

+-----------------+------------------------------+
|final_gender_code|min(final_gender_code_feature)|
+-----------------+------------------------------+
|                1|                           1.0|
|                2|                           0.0|
+-----------------+------------------------------+

+---------+----------------------+
|age_level|min(age_level_feature)|
+---------+----------------------+
|        3|                   0.0|
|        0|                   6.0|
|        5|                   2.0|
|        6|                   5.0|
|        1|                   4.0|
|        4|                   1.0|
|        2|                   3.0|
+---------+----------------------+

+--------------+---------------------------+
|shopping_level|min(shopping_level_feature)|
+--------------+---------------------------+
|             3|                        0.0|
|             1|                        2.0|
|             2|                        1.0|
+--------------+---------------------------+

+----------+-----------------------+
|occupation|min(occupation_feature)|
+----------+-----------------------+
|         0|                    0.0|
|         1|                    1.0|
+----------+-----------------------+

# 由于热独编码后,特征字段不再是之前的字段,重新定义特征值字段
feature_cols = [
    # 特征值
    "price",
    "cms_group_id_value",
    "final_gender_code_value",
    "age_level_value",
    "shopping_level_value",
    "occupation_value",
    "pid_value",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 根据特征字段计算出特征向量,并划分出训练数据集和测试数据集
from pyspark.ml.feature import VectorAssembler
datasets_2 = VectorAssembler().setInputCols(feature_cols).setOutputCol("features").transform(datasets_2)
train_datasets_2 = datasets_2.filter(datasets_2.timestamp<=(1494691186-24*60*60))
test_datasets_2 = datasets_2.filter(datasets_2.timestamp>(1494691186-24*60*60))
train_datasets_2.printSchema()
train_datasets_2.first()

显示结果:

root
 |-- timestamp: long (nullable = true)
 |-- clk: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)
 |-- cms_group_id_feature: double (nullable = false)
 |-- cms_group_id_value: vector (nullable = true)
 |-- final_gender_code_feature: double (nullable = false)
 |-- final_gender_code_value: vector (nullable = true)
 |-- age_level_feature: double (nullable = false)
 |-- age_level_value: vector (nullable = true)
 |-- shopping_level_feature: double (nullable = false)
 |-- shopping_level_value: vector (nullable = true)
 |-- occupation_feature: double (nullable = false)
 |-- occupation_value: vector (nullable = true)
 |-- features: vector (nullable = true)

Row(timestamp=1494261938, clk=0, price=108.0, cms_group_id='11', final_gender_code='1', age_level='5', shopping_level='3', occupation='0', pid_value=SparseVector(2, {1: 1.0}), pl_onehot_value=SparseVector(4, {0: 1.0}), nucl_onehot_value=SparseVector(5, {1: 1.0}), cms_group_id_feature=6.0, cms_group_id_value=SparseVector(13, {6: 1.0}), final_gender_code_feature=1.0, final_gender_code_value=SparseVector(2, {1: 1.0}), age_level_feature=2.0, age_level_value=SparseVector(7, {2: 1.0}), shopping_level_feature=0.0, shopping_level_value=SparseVector(3, {0: 1.0}), occupation_feature=0.0, occupation_value=SparseVector(2, {0: 1.0}), features=SparseVector(39, {0: 108.0, 7: 1.0, 15: 1.0, 18: 1.0, 23: 1.0, 26: 1.0, 29: 1.0, 30: 1.0, 35: 1.0}))

  • 创建逻辑回归训练器,并训练模型
from pyspark.ml.classification import LogisticRegression
lr2 = LogisticRegression()
#设置目标值对应的列   setFeaturesCol 设置特征值对应的列名
model2 = lr2.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_2)
# 存储模型
model2.save("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
from pyspark.ml.classification import LogisticRegressionModel
# 载入训练好的模型
model2 = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
result_2 = model2.transform(test_datasets_2)
# 按probability升序排列数据,probability表示预测结果的概率
result_2.select("clk", "price", "probability", "prediction").sort("probability").show(100)

# 对比前面的result_1的预测结果,能发现这里的预测率稍微准确了一点,这里top20里出现了3个点击的,但前面的只出现了1个
# 因此可见对特征的细化处理,已经帮助我们提高模型的效果的

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.85524418892857...|       0.0|
|  0|      1.0E8|[0.88353143762124...|       0.0|
|  0|      1.0E8|[0.89169808985616...|       0.0|
|  1|5.5555556E7|[0.92511743960350...|       0.0|
|  0|     179.01|[0.93239951738307...|       0.0|
|  1|      159.0|[0.93239952905659...|       0.0|
|  0|      118.0|[0.93239955297535...|       0.0|
|  0|      688.0|[0.93451506165953...|       0.0|
|  0|      339.0|[0.93451525933626...|       0.0|
|  0|      335.0|[0.93451526160190...|       0.0|
|  0|      220.0|[0.93451532673881...|       0.0|
|  0|      176.0|[0.93451535166074...|       0.0|
|  0|      158.0|[0.93451536185607...|       0.0|
|  0|      158.0|[0.93451536185607...|       0.0|
|  1|      149.0|[0.93451536695374...|       0.0|
|  0|      122.5|[0.93451538196353...|       0.0|
|  0|       99.0|[0.93451539527410...|       0.0|
|  0|       88.0|[0.93451540150458...|       0.0|
|  0|       79.0|[0.93451540660224...|       0.0|
|  0|       75.0|[0.93451540886787...|       0.0|
|  0|       68.0|[0.93451541283272...|       0.0|
|  0|       68.0|[0.93451541283272...|       0.0|
|  0|       59.9|[0.93451541742061...|       0.0|
|  0|      44.98|[0.93451542587140...|       0.0|
|  0|       35.5|[0.93451543124094...|       0.0|
|  0|       33.0|[0.93451543265696...|       0.0|
|  0|       32.8|[0.93451543277024...|       0.0|
|  0|       30.0|[0.93451543435618...|       0.0|
|  0|       28.0|[0.93451543548899...|       0.0|
|  0|       19.9|[0.93451544007688...|       0.0|
|  0|       19.8|[0.93451544013353...|       0.0|
|  0|       19.8|[0.93451544013353...|       0.0|
|  0|       12.0|[0.93451544455150...|       0.0|
|  0|        6.7|[0.93451544755345...|       0.0|
|  0|      568.0|[0.93458159339238...|       0.0|
|  0|      398.0|[0.93458168959099...|       0.0|
|  0|      158.0|[0.93458182540058...|       0.0|
|  0|      245.0|[0.93471518526899...|       0.0|
|  0|       99.0|[0.93471526772971...|       0.0|
|  0|       88.0|[0.93471527394249...|       0.0|
|  0|     1288.0|[0.93474589600376...|       0.0|
|  0|      688.0|[0.93474623473450...|       0.0|
|  0|      656.0|[0.93474625280009...|       0.0|
|  0|      568.0|[0.93474630248045...|       0.0|
|  0|      498.0|[0.93474634199889...|       0.0|
|  0|      399.0|[0.93474639788922...|       0.0|
|  0|      396.0|[0.93474639958287...|       0.0|
|  0|      298.0|[0.93474645490860...|       0.0|
|  0|      293.0|[0.93474645773134...|       0.0|
|  0|      209.0|[0.93474650515337...|       0.0|
|  0|      198.0|[0.93474651136339...|       0.0|
|  0|      198.0|[0.93474651136339...|       0.0|
|  0|      169.0|[0.93474652773527...|       0.0|
|  0|      168.0|[0.93474652829982...|       0.0|
|  0|      159.0|[0.93474653338074...|       0.0|
|  0|      155.0|[0.93474653563893...|       0.0|
|  0|      139.0|[0.93474654467169...|       0.0|
|  0|      138.0|[0.93474654523624...|       0.0|
|  0|      119.0|[0.93474655596264...|       0.0|
|  0|       99.0|[0.93474656725358...|       0.0|
|  0|       99.0|[0.93474656725358...|       0.0|
|  0|       88.0|[0.93474657346360...|       0.0|
|  0|       88.0|[0.93474657346360...|       0.0|
|  0|       79.0|[0.93474657854453...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       59.0|[0.93474658983547...|       0.0|
|  0|       58.0|[0.93474659040002...|       0.0|
|  0|       57.0|[0.93474659096456...|       0.0|
|  0|       49.8|[0.93474659502930...|       0.0|
|  0|      39.98|[0.93474660057315...|       0.0|
|  0|       36.8|[0.93474660236841...|       0.0|
|  0|       34.0|[0.93474660394914...|       0.0|
|  0|     6520.0|[0.93480919087761...|       0.0|
|  0|     3699.0|[0.93481078202537...|       0.0|
|  0|     1980.0|[0.93481175158689...|       0.0|
|  0|      660.0|[0.93481249609274...|       0.0|
|  0|      660.0|[0.93481249609274...|       0.0|
|  0|      398.0|[0.93481264386492...|       0.0|
|  0|      369.0|[0.93481266022137...|       0.0|
|  0|      299.0|[0.93481269970243...|       0.0|
|  0|      295.0|[0.93481270195849...|       0.0|
|  0|      278.0|[0.93481271154674...|       0.0|
|  0|      270.0|[0.93481271605886...|       0.0|
|  0|      228.0|[0.93481273974748...|       0.0|
|  0|      228.0|[0.93481273974748...|       0.0|
|  0|    11368.0|[0.93494253131370...|       0.0|
|  0|     9999.0|[0.93494330201510...|       0.0|
|  0|     1099.0|[0.93494360670448...|       0.0|
|  1|     8888.0|[0.93494392746484...|       0.0|
|  0|      338.0|[0.93494403511659...|       0.0|
|  0|      311.0|[0.93494405031645...|       0.0|
|  0|      300.0|[0.93494405650898...|       0.0|
|  0|      278.0|[0.93494406889404...|       0.0|
|  0|      188.0|[0.93494411956019...|       0.0|
|  0|      176.0|[0.93494412631568...|       0.0|
|  0|      168.0|[0.93494413081933...|       0.0|
|  0|      158.0|[0.93494413644890...|       0.0|
|  1|      138.0|[0.93494414770804...|       0.0|
|  0|      125.0|[0.93494415502647...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

result_2.filter(result_2.clk==1).select("clk", "price", "probability", "prediction").sort("probability").show(100)
# 从该结果也可以看出,result_2的点击率预测率普遍要比result_1高出一点点

显示结果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92511743960350...|       0.0|
|  1|      159.0|[0.93239952905659...|       0.0|
|  1|      149.0|[0.93451536695374...|       0.0|
|  1|     8888.0|[0.93494392746484...|       0.0|
|  1|      138.0|[0.93494414770804...|       0.0|
|  1|       35.0|[0.93494420569256...|       0.0|
|  1|      519.0|[0.93494863870621...|       0.0|
|  1|      478.0|[0.93494866178596...|       0.0|
|  1|      349.0|[0.93494873440265...|       0.0|
|  1|      348.0|[0.93494873496557...|       0.0|
|  1|      316.0|[0.93494875297901...|       0.0|
|  1|      298.0|[0.93494876311156...|       0.0|
|  1|      298.0|[0.93494876311156...|       0.0|
|  1|      199.0|[0.93494881884058...|       0.0|
|  1|      199.0|[0.93494881884058...|       0.0|
|  1|      198.0|[0.93494881940350...|       0.0|
|  1|      187.1|[0.93494882553931...|       0.0|
|  1|      176.0|[0.93494883178772...|       0.0|
|  1|      168.0|[0.93494883629107...|       0.0|
|  1|      160.0|[0.93494884079442...|       0.0|
|  1|      158.0|[0.93494884192026...|       0.0|
|  1|      158.0|[0.93494884192026...|       0.0|
|  1|      135.0|[0.93494885486740...|       0.0|
|  1|      129.0|[0.93494885824491...|       0.0|
|  1|      127.0|[0.93494885937075...|       0.0|
|  1|      125.0|[0.93494886049659...|       0.0|
|  1|      124.0|[0.93494886105951...|       0.0|
|  1|      118.0|[0.93494886443702...|       0.0|
|  1|      109.0|[0.93494886950329...|       0.0|
|  1|      108.0|[0.93494887006621...|       0.0|
|  1|       99.0|[0.93494887513247...|       0.0|
|  1|       98.0|[0.93494887569539...|       0.0|
|  1|       79.8|[0.93494888594051...|       0.0|
|  1|       79.0|[0.93494888639085...|       0.0|
|  1|       77.0|[0.93494888751668...|       0.0|
|  1|       72.5|[0.93494889004982...|       0.0|
|  1|       69.0|[0.93494889202003...|       0.0|
|  1|       68.0|[0.93494889258295...|       0.0|
|  1|       60.0|[0.93494889708630...|       0.0|
|  1|      43.98|[0.93494890610426...|       0.0|
|  1|       40.0|[0.93494890834467...|       0.0|
|  1|       39.9|[0.93494890840096...|       0.0|
|  1|       39.6|[0.93494890856984...|       0.0|
|  1|       32.0|[0.93494891284802...|       0.0|
|  1|       31.0|[0.93494891341094...|       0.0|
|  1|      25.98|[0.93494891623679...|       0.0|
|  1|       23.0|[0.93494891791428...|       0.0|
|  1|       19.0|[0.93494892016596...|       0.0|
|  1|       16.9|[0.93494892134809...|       0.0|
|  1|       10.0|[0.93494892523222...|       0.0|
|  1|        3.5|[0.93494892889119...|       0.0|
|  1|        3.5|[0.93494892889119...|       0.0|
|  1|        0.4|[0.93494893063624...|       0.0|
|  1|     1288.0|[0.93501426059874...|       0.0|
|  1|      980.0|[0.93501443381533...|       0.0|
|  1|      788.0|[0.93501454179429...|       0.0|
|  1|      698.0|[0.93501459240937...|       0.0|
|  1|      695.0|[0.93501459409654...|       0.0|
|  1|      688.0|[0.93501459803326...|       0.0|
|  1|      599.0|[0.93501464808591...|       0.0|
|  1|      588.0|[0.93501465427219...|       0.0|
|  1|      516.0|[0.93501469476419...|       0.0|
|  1|      495.0|[0.93501470657436...|       0.0|
|  1|      398.0|[0.93501476112603...|       0.0|
|  1|      368.0|[0.93501477799768...|       0.0|
|  1|      339.0|[0.93501479430693...|       0.0|
|  1|      335.0|[0.93501479655648...|       0.0|
|  1|      324.0|[0.93501480274275...|       0.0|
|  1|      316.0|[0.93501480724185...|       0.0|
|  1|      299.0|[0.93501481680244...|       0.0|
|  1|      295.0|[0.93501481905199...|       0.0|
|  1|      279.0|[0.93501482805020...|       0.0|
|  1|      268.0|[0.93501483423646...|       0.0|
|  1|      259.0|[0.93501483929795...|       0.0|
|  1|      259.0|[0.93501483929795...|       0.0|
|  1|      249.0|[0.93501484492182...|       0.0|
|  1|      238.0|[0.93501485110809...|       0.0|
|  1|      199.0|[0.93501487304119...|       0.0|
|  1|      198.0|[0.93501487360358...|       0.0|
|  1|      179.0|[0.93501488428894...|       0.0|
|  1|      175.0|[0.93501488653849...|       0.0|
|  1|      129.0|[0.93501491240829...|       0.0|
|  1|      128.0|[0.93501491297068...|       0.0|
|  1|      118.0|[0.93501491859455...|       0.0|
|  1|      109.0|[0.93501492365603...|       0.0|
|  1|       98.0|[0.93501492984229...|       0.0|
|  1|       89.0|[0.93501493490377...|       0.0|
|  1|       79.0|[0.93501494052764...|       0.0|
|  1|       75.0|[0.93501494277718...|       0.0|
|  1|       69.8|[0.93501494570159...|       0.0|
|  1|       30.0|[0.93501496808458...|       0.0|
|  1|       15.0|[0.93501497652038...|       0.0|
|  1|      368.0|[0.93665387743951...|       0.0|
|  1|      198.0|[0.93665397079735...|       0.0|
|  1|      178.0|[0.93665398178062...|       0.0|
|  1|      158.0|[0.93665399276388...|       0.0|
|  1|      158.0|[0.93665399276388...|       0.0|
|  1|      149.0|[0.93665399770635...|       0.0|
|  1|       68.0|[0.93665404218855...|       0.0|
|  1|       36.0|[0.93665405976176...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

五 离线推荐数据缓存

5.1离线数据缓存之离线召回集

  • 这里主要是利用我们前面训练的ALS模型进行协同过滤召回,但是注意,我们ALS模型召回的是用户最感兴趣的类别,而我们需要的是用户可能感兴趣的广告的集合,因此我们还需要根据召回的类别匹配出对应的广告。

    所以这里我们除了需要我们训练的ALS模型以外,还需要有一个广告和类别的对应关系。

# 从HDFS中加载广告基本信息数据,返回spark dafaframe对象
df = spark.read.csv("hdfs://localhost:8020/csv/ad_feature.csv", header=True)

# 注意:由于本数据集中存在NULL字样的数据,无法直接设置schema,只能先将NULL类型的数据处理掉,然后进行类型转换

from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串,替换掉
df = df.replace("NULL", "-1")

# 更改df表结构:更改列类型和列名称
ad_feature_df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", df.price.cast(FloatType()))

# 这里我们只需要adgroupId、和cateId
_ = ad_feature_df.select("adgroupId", "cateId")
# 由于这里数据集其实很少,所以我们再直接转成Pandas dataframe来处理,把数据载入内存
pdf = _.toPandas()


# 手动释放一些内存
del df
del ad_feature_df
del _
import gc
gc.collect()
  • 根据指定的类别找到对应的广告
import numpy as np
pdf.where(pdf.cateId==11156).dropna().adgroupId

np.random.choice(pdf.where(pdf.cateId==11156).dropna().adgroupId.astype(np.int64), 200)

显示结果:

313       138953.0
314       467512.0
1661      140008.0
1666      238772.0
1669      237471.0
1670      238761.0
			...   
843456    352273.0
846728    818681.0
846729    838953.0
846810    845337.0
Name: adgroupId, Length: 731, dtype: float64

  • 利用ALS模型进行类别的召回
# 加载als模型,注意必须先有spark上下文管理器,即sparkContext,但这里sparkSession创建后,自动创建了sparkContext

from pyspark.ml.recommendation import ALSModel
# 从hdfs加载之前存储的模型
als_model = ALSModel.load("hdfs://localhost:8020/models/userCateRatingALSModel.obj")
# 返回模型中关于用户的所有属性   df:   id   features
als_model.userFactors

显示结果:

DataFrame[id: int, features: array<float>]
import pandas as pd
cateId_df = pd.DataFrame(pdf.cateId.unique(),columns=["cateId"])
cateId_df

显示结果:

	cateId
0	1
1	2
2	3
3	4
4	5
5	6
6	7
...	...
6766	12948
6767	12955
6768	12960
6769 rows × 1 columns

cateId_df.insert(0, "userId", np.array([8 for i in range(6769)]))
cateId_df

显示结果:

 userId cateId
0	8	1
1	8	2
2	8	3
3	8	4
4	8	5
...	...	...
6766	8	12948
6767	8	12955
6768	8	12960
6769 rows × 2 columns

  • 传入 userid、cataId的df,对应预测值进行排序
als_model.transform(spark.createDataFrame(cateId_df)).sort("prediction", ascending=False).na.drop().show()

显示结果:

+------+------+----------+
|userId|cateId|prediction|
+------+------+----------+
|     8|  7214|  9.917084|
|     8|   877|  7.479664|
|     8|  7266| 7.4762917|
|     8| 10856| 7.3395424|
|     8|  4766|  7.149538|
|     8|  7282| 6.6835284|
|     8|  7270| 6.2145095|
|     8|   201| 6.0623236|
|     8|  4267| 5.9155636|
|     8|  7267|  5.838009|
|     8|  5392| 5.6882005|
|     8|  6261| 5.6804466|
|     8|  6306| 5.2992325|
|     8| 11050|  5.245261|
|     8|  8655| 5.1701374|
|     8|  4610|  5.139578|
|     8|   932|   5.12694|
|     8| 12276| 5.0776596|
|     8|  8071|  4.979195|
|     8|  6580| 4.8523283|
+------+------+----------+
only showing top 20 rows

import numpy as np
import pandas as pd

import redis

# 存储用户召回,使用redis第9号数据库,类型:sets类型
client = redis.StrictRedis(host="192.168.199.188", port=6379, db=9)

for r in als_model.userFactors.select("id").collect():
    
    userId = r.id
    
    cateId_df = pd.DataFrame(pdf.cateId.unique(),columns=["cateId"])
    cateId_df.insert(0, "userId", np.array([userId for i in range(6769)]))
    ret = set()
    
    # 利用模型,传入datasets(userId, cateId),这里控制了userId一样,所以相当于是在求某用户对所有分类的兴趣程度
    cateId_list = als_model.transform(spark.createDataFrame(cateId_df)).sort("prediction", ascending=False).na.drop()
    # 从前20个分类中选出500个进行召回
    for i in cateId_list.head(20):
        need = 500 - len(ret)    # 如果不足500个,那么随机选出need个广告
        ret = ret.union(np.random.choice(pdf.where(pdf.cateId==i.cateId).adgroupId.dropna().astype(np.int64), need))
        if len(ret) >= 500:    # 如果达到500个则退出
            break
    client.sadd(userId, *ret)
    
# 如果redis所在机器,内存不足,会抛出异常

5.2 离线数据缓存之离线特征

# "pid", 广告资源位,属于场景特征,也就是说,每一种广告通常是可以防止在多种资源外下的
# 因此这里对于pid,应该是由广告系统发起推荐请求时,向推荐系统明确要推荐的用户是谁,以及对应的资源位,或者说有哪些
# 这样如果有多个资源位,那么每个资源位都会对应相应的一个推荐列表

# 需要进行缓存的特征值
    
feature_cols_from_ad = [
    "price"    # 来自广告基本信息中
]

# 用户特征
feature_cols_from_user = [
    "cms_group_id",
    "final_gender_code",
    "age_level",
    "shopping_level",
    "occupation",
    "pvalue_level",
    "new_user_class_level"
]
  • 从HDFS中加载广告基本信息数据
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)

# 更改表结构,转换为对应的数据类型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替换掉NULL字符串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表结构:更改列类型和列名称
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
    
def foreachPartition(partition):
    
    import redis
    import json
    client = redis.StrictRedis(host="192.168.199.188", port=6379, db=10)
    
    for r in partition:
        data = {
    
    
            "price": r.price
        }
        # 转成json字符串再保存,能保证数据再次倒出来时,能有效的转换成python类型
        client.hset("ad_features", r.adgroupId, json.dumps(data))
        
ad_feature_df.foreachPartition(foreachPartition)
  • 从HDFS加载用户基本信息数据
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 构建表结构schema对象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema从hdfs加载
user_profile_df = spark.read.csv("hdfs://localhost:8020/csv/user_profile.csv", header=True, schema=schema)
user_profile_df

显示结果:

DataFrame[userId: int, cms_segid: int, cms_group_id: int, final_gender_code: int, age_level: int, pvalue_level: int, shopping_level: int, occupation: int, new_user_class_level: int]

def foreachPartition2(partition):
    
    import redis
    import json
    client = redis.StrictRedis(host="192.168.199.188", port=6379, db=10)
    
    for r in partition:
        data = {
    
    
            "cms_group_id": r.cms_group_id,
            "final_gender_code": r.final_gender_code,
            "age_level": r.age_level,
            "shopping_level": r.shopping_level,
            "occupation": r.occupation,
            "pvalue_level": r.pvalue_level,
            "new_user_class_level": r.new_user_class_level
        }
        # 转成json字符串再保存,能保证数据再次倒出来时,能有效的转换成python类型
        client.hset("user_features1", r.userId, json.dumps(data))
        
user_profile_df.foreachPartition(foreachPartition2)

六 实时产生推荐结果

6.1 推荐任务处理

  • CTR预测模型 + 特征 ==> 预测结果 ==> TOP-N列表

  • 热编码中:"pvalue_level"特征对应关系:

+------------+----------------------+
|pvalue_level|pl_onehot_feature     |
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+
  • “new_user_class_level”的特征对应关系:
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature     |
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+
pvalue_level_rela = {
    
    -1: 0, 3:3, 1:2, 2:1}
new_user_class_level_rela = {
    
    -1:0, 3:2, 1:4, 4:3, 2:1}
  • "cms_group_id"特征对应关系:
+------------+-------------------------+
|cms_group_id|min(cms_group_id_feature)|
+------------+-------------------------+
|           7|                      9.0|
|          11|                      6.0|
|           3|                      0.0|
|           8|                      8.0|
|           0|                     12.0|
|           5|                      3.0|
|           6|                     10.0|
|           9|                      5.0|
|           1|                      7.0|
|          10|                      4.0|
|           4|                      1.0|
|          12|                     11.0|
|           2|                      2.0|
+------------+-------------------------+
cms_group_id_rela = {
    7: 9,
    11: 6,
    3: 0,
    8: 8,
    0: 12,
    5: 3,
    6: 10,
    9: 5,
    1: 7,
    10: 4,
    4: 1,
    12: 11,
    2: 2
}
  • "final_gender_code"特征对应关系:
+-----------------+------------------------------+
|final_gender_code|min(final_gender_code_feature)|
+-----------------+------------------------------+
|                1|                           1.0|
|                2|                           0.0|
+-----------------+------------------------------+
final_gender_code_rela = {1:1, 2:0}
  • "age_level"特征对应关系:
+---------+----------------------+
|age_level|min(age_level_feature)|
+---------+----------------------+
|        3|                   0.0|
|        0|                   6.0|
|        5|                   2.0|
|        6|                   5.0|
|        1|                   4.0|
|        4|                   1.0|
|        2|                   3.0|
+---------+----------------------+
age_level_rela = {3:0, 0:6, 5:2, 6:5, 1:4, 4:1, 2:3}
  • "shopping_level"特征对应关系:
|shopping_level|min(shopping_level_feature)|
+--------------+---------------------------+
|             3|                        0.0|
|             1|                        2.0|
|             2|                        1.0|
+--------------+---------------------------+
shopping_level_rela = {3:0, 1:2, 2:1}
  • "occupation"特征对应关系:
+----------+-----------------------+
|occupation|min(occupation_feature)|
+----------+-----------------------+
|         0|                    0.0|
|         1|                    1.0|
+----------+-----------------------+
occupation_rela = {0:0, 1:1}

pid_rela = {
    "430548_1007": 0, 
    "430549_1007": 1
}
  • 特征获取
import redis
import json
import pandas as pd
from pyspark.ml.linalg import DenseVector


def create_datasets(userId, pid):
    client_of_recall = redis.StrictRedis(host="192.168.199.88", port=6379, db=9)
    client_of_features = redis.StrictRedis(host="192.168.199.88", port=6379, db=10)
    # 获取用户特征
    user_feature = json.loads(client_of_features.hget("user_features", userId))
    
    # 获取用户召回集
    recall_sets = client_of_recall.smembers(userId)
    
    result = []
    
    # 遍历召回集
    for adgroupId in recall_sets:
        adgroupId = int(adgroupId)
        # 获取该广告的特征值
        ad_feature = json.loads(client_of_features.hget("ad_features", adgroupId))
        
        features = {
    
    }
        features.update(user_feature)
        features.update(ad_feature)

        for k,v in features.items():
            if v is None:
                features[k] = -1

        features_col = [
            # 特征值
            "price",
            "cms_group_id",
            "final_gender_code",
            "age_level",
            "shopping_level",
            "occupation",
            "pid", 
            "pvalue_level",
            "new_user_class_level"
        ]
        '''
        "cms_group_id", 类别型特征,约13个分类 ==> 13维
        "final_gender_code", 类别型特征,2个分类 ==> 2维
        "age_level", 类别型特征,7个分类 ==>7维
        "shopping_level", 类别型特征,3个分类 ==> 3维
        "occupation", 类别型特征,2个分类 ==> 2维
        '''

        price = float(features["price"])

        pid_value = [0 for i in range(2)]#[0,0]
        cms_group_id_value = [0 for i in range(13)]
        final_gender_code_value = [0 for i in range(2)]
        age_level_value = [0 for i in range(7)]
        shopping_level_value = [0 for i in range(3)]
        occupation_value = [0 for i in range(2)]
        pvalue_level_value = [0 for i in range(4)]
        new_user_class_level_value = [0 for i in range(5)]

        pid_value[pid_rela[pid]] = 1
        cms_group_id_value[cms_group_id_rela[int(features["cms_group_id"])]] = 1
        final_gender_code_value[final_gender_code_rela[int(features["final_gender_code"])]] = 1
        age_level_value[age_level_rela[int(features["age_level"])]] = 1
        shopping_level_value[shopping_level_rela[int(features["shopping_level"])]] = 1
        occupation_value[occupation_rela[int(features["occupation"])]] = 1
        pvalue_level_value[pvalue_level_rela[int(features["pvalue_level"])]] = 1
        new_user_class_level_value[new_user_class_level_rela[int(features["new_user_class_level"])]] = 1
 #         print(pid_value)
#         print(cms_group_id_value)
#         print(final_gender_code_value)
#         print(age_level_value)
#         print(shopping_level_value)
#         print(occupation_value)
#         print(pvalue_level_value)
#         print(new_user_class_level_value)
        
        vector = DenseVector([price] + pid_value + cms_group_id_value + final_gender_code_value\
        + age_level_value + shopping_level_value + occupation_value + pvalue_level_value + new_user_class_level_value)
        
        result.append((userId, adgroupId, vector))
        
    return result

# create_datasets(88, "430548_1007")
  • 载入训练好的模型
from pyspark.ml.classification import LogisticRegressionModel
CTR_model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
pdf = pd.DataFrame(create_datasets(8, "430548_1007"), columns=["userId", "adgroupId", "features"])
datasets = spark.createDataFrame(pdf)
datasets.show()

显示结果:

+------+---------+--------------------+
|userId|adgroupId|            features|
+------+---------+--------------------+
|     8|   445914|[9.89999961853027...|
|     8|   258252|[7.59999990463256...|
|     8|   129682|[8.5,1.0,0.0,1.0,...|
|     8|   763027|[68.0,1.0,0.0,1.0...|
|     8|   292027|[16.0,1.0,0.0,1.0...|
|     8|   430023|[34.2000007629394...|
|     8|   133457|[169.0,1.0,0.0,1....|
|     8|   816999|[5.0,1.0,0.0,1.0,...|
|     8|   221714|[4.80000019073486...|
|     8|   186334|[106.0,1.0,0.0,1....|
|     8|   169717|[2.20000004768371...|
|     8|    31314|[15.8000001907348...|
|     8|   815312|[2.29999995231628...|
|     8|   199445|[5.0,1.0,0.0,1.0,...|
|     8|   746178|[16.7999992370605...|
|     8|   290950|[6.5,1.0,0.0,1.0,...|
|     8|   221585|[18.5,1.0,0.0,1.0...|
|     8|   692672|[47.0,1.0,0.0,1.0...|
|     8|   797982|[33.0,1.0,0.0,1.0...|
|     8|   815219|[2.40000009536743...|
+------+---------+--------------------+
only showing top 20 rows
prediction = CTR_model.transform(datasets).sort("probability")
prediction.show()
+------+---------+--------------------+--------------------+--------------------+----------+
|userId|adgroupId|            features|       rawPrediction|         probability|prediction|
+------+---------+--------------------+--------------------+--------------------+----------+
|     8|   631204|[19888.0,1.0,0.0,...|[2.69001234046578...|[0.93643471623189...|       0.0|
|     8|   583215|[3750.0,1.0,0.0,1...|[2.69016170680037...|[0.93644360664433...|       0.0|
|     8|   275819|[3280.0,1.0,0.0,1...|[2.69016605691669...|[0.93644386554961...|       0.0|
|     8|   401433|[1200.0,1.0,0.0,1...|[2.69018530849532...|[0.93644501133142...|       0.0|
|     8|    29466|[640.0,1.0,0.0,1....|[2.69019049161265...|[0.93644531980785...|       0.0|
|     8|   173327|[356.0,1.0,0.0,1....|[2.69019312019358...|[0.93644547624893...|       0.0|
|     8|   241402|[269.0,1.0,0.0,1....|[2.69019392542787...|[0.93644552417271...|       0.0|
|     8|   351366|[246.0,1.0,0.0,1....|[2.69019413830591...|[0.93644553684221...|       0.0|
|     8|   229827|[238.0,1.0,0.0,1....|[2.69019421235044...|[0.93644554124900...|       0.0|
|     8|   164807|[228.0,1.0,0.0,1....|[2.69019430490611...|[0.93644554675747...|       0.0|
|     8|   227731|[199.0,1.0,0.0,1....|[2.69019457331754...|[0.93644556273205...|       0.0|
|     8|   265403|[198.0,1.0,0.0,1....|[2.69019458257311...|[0.93644556328290...|       0.0|
|     8|   569939|[188.0,1.0,0.0,1....|[2.69019467512877...|[0.93644556879138...|       0.0|
|     8|   277335|[181.5,1.0,0.0,1....|[2.69019473528996...|[0.93644557237189...|       0.0|
|     8|   575633|[180.0,1.0,0.0,1....|[2.69019474917331...|[0.93644557319816...|       0.0|
|     8|   201867|[179.0,1.0,0.0,1....|[2.69019475842887...|[0.93644557374900...|       0.0|
|     8|    25542|[176.0,1.0,0.0,1....|[2.69019478619557...|[0.93644557540155...|       0.0|
|     8|   133457|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   494224|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   339382|[163.0,1.0,0.0,1....|[2.69019490651794...|[0.93644558256256...|       0.0|
+------+---------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
  • TOP-20
# TOP-20
prediction.select("adgroupId").head(20)

显示结果:

[Row(adgroupId=631204),
 Row(adgroupId=583215),
 Row(adgroupId=275819),
 Row(adgroupId=401433),
 Row(adgroupId=29466),
 Row(adgroupId=173327),
 Row(adgroupId=241402),
 Row(adgroupId=351366),
 Row(adgroupId=229827),
 Row(adgroupId=164807),
 Row(adgroupId=227731),
 Row(adgroupId=265403),
 Row(adgroupId=569939),
 Row(adgroupId=277335),
 Row(adgroupId=575633),
 Row(adgroupId=201867),
 Row(adgroupId=25542),
 Row(adgroupId=133457),
 Row(adgroupId=494224),
 Row(adgroupId=339382)]
[i.adgroupId for i in prediction.select("adgroupId").head(20)]

显示结果:

[631204,
 583215,
 275819,
 401433,
 29466,
 173327,
 241402,
 351366,
 229827,
 164807,
 227731,
 265403,
 569939,
 277335,
 575633,
 201867,
 25542,
 133457,
 494224,
 339382]

,1.0,…|
| 8| 763027|[68.0,1.0,0.0,1.0…|
| 8| 292027|[16.0,1.0,0.0,1.0…|
| 8| 430023|[34.2000007629394…|
| 8| 133457|[169.0,1.0,0.0,1…|
| 8| 816999|[5.0,1.0,0.0,1.0,…|
| 8| 221714|[4.80000019073486…|
| 8| 186334|[106.0,1.0,0.0,1…|
| 8| 169717|[2.20000004768371…|
| 8| 31314|[15.8000001907348…|
| 8| 815312|[2.29999995231628…|
| 8| 199445|[5.0,1.0,0.0,1.0,…|
| 8| 746178|[16.7999992370605…|
| 8| 290950|[6.5,1.0,0.0,1.0,…|
| 8| 221585|[18.5,1.0,0.0,1.0…|
| 8| 692672|[47.0,1.0,0.0,1.0…|
| 8| 797982|[33.0,1.0,0.0,1.0…|
| 8| 815219|[2.40000009536743…|
±-----±--------±-------------------+
only showing top 20 rows


prediction = CTR_model.transform(datasets).sort(“probability”)
prediction.show()


```shell
+------+---------+--------------------+--------------------+--------------------+----------+
|userId|adgroupId|            features|       rawPrediction|         probability|prediction|
+------+---------+--------------------+--------------------+--------------------+----------+
|     8|   631204|[19888.0,1.0,0.0,...|[2.69001234046578...|[0.93643471623189...|       0.0|
|     8|   583215|[3750.0,1.0,0.0,1...|[2.69016170680037...|[0.93644360664433...|       0.0|
|     8|   275819|[3280.0,1.0,0.0,1...|[2.69016605691669...|[0.93644386554961...|       0.0|
|     8|   401433|[1200.0,1.0,0.0,1...|[2.69018530849532...|[0.93644501133142...|       0.0|
|     8|    29466|[640.0,1.0,0.0,1....|[2.69019049161265...|[0.93644531980785...|       0.0|
|     8|   173327|[356.0,1.0,0.0,1....|[2.69019312019358...|[0.93644547624893...|       0.0|
|     8|   241402|[269.0,1.0,0.0,1....|[2.69019392542787...|[0.93644552417271...|       0.0|
|     8|   351366|[246.0,1.0,0.0,1....|[2.69019413830591...|[0.93644553684221...|       0.0|
|     8|   229827|[238.0,1.0,0.0,1....|[2.69019421235044...|[0.93644554124900...|       0.0|
|     8|   164807|[228.0,1.0,0.0,1....|[2.69019430490611...|[0.93644554675747...|       0.0|
|     8|   227731|[199.0,1.0,0.0,1....|[2.69019457331754...|[0.93644556273205...|       0.0|
|     8|   265403|[198.0,1.0,0.0,1....|[2.69019458257311...|[0.93644556328290...|       0.0|
|     8|   569939|[188.0,1.0,0.0,1....|[2.69019467512877...|[0.93644556879138...|       0.0|
|     8|   277335|[181.5,1.0,0.0,1....|[2.69019473528996...|[0.93644557237189...|       0.0|
|     8|   575633|[180.0,1.0,0.0,1....|[2.69019474917331...|[0.93644557319816...|       0.0|
|     8|   201867|[179.0,1.0,0.0,1....|[2.69019475842887...|[0.93644557374900...|       0.0|
|     8|    25542|[176.0,1.0,0.0,1....|[2.69019478619557...|[0.93644557540155...|       0.0|
|     8|   133457|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   494224|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   339382|[163.0,1.0,0.0,1....|[2.69019490651794...|[0.93644558256256...|       0.0|
+------+---------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
  • TOP-20
# TOP-20
prediction.select("adgroupId").head(20)

显示结果:

[Row(adgroupId=631204),
 Row(adgroupId=583215),
 Row(adgroupId=275819),
 Row(adgroupId=401433),
 Row(adgroupId=29466),
 Row(adgroupId=173327),
 Row(adgroupId=241402),
 Row(adgroupId=351366),
 Row(adgroupId=229827),
 Row(adgroupId=164807),
 Row(adgroupId=227731),
 Row(adgroupId=265403),
 Row(adgroupId=569939),
 Row(adgroupId=277335),
 Row(adgroupId=575633),
 Row(adgroupId=201867),
 Row(adgroupId=25542),
 Row(adgroupId=133457),
 Row(adgroupId=494224),
 Row(adgroupId=339382)]
[i.adgroupId for i in prediction.select("adgroupId").head(20)]

显示结果:

[631204,
 583215,
 275819,
 401433,
 29466,
 173327,
 241402,
 351366,
 229827,
 164807,
 227731,
 265403,
 569939,
 277335,
 575633,
 201867,
 25542,
 133457,
 494224,
 339382]

猜你喜欢

转载自blog.csdn.net/fegus/article/details/124445822
今日推荐