Spark SQL Project: Top-N Popular Products per Region

I. Requirements

1.1 Requirement overview
Here, "popular" products are measured by click count.

Compute the top three popular products in each region, and annotate each product with the share of its clicks that comes from its main cities; cities beyond the top two are grouped under "其他" (Other).




1.2 Approach
The job is written in SQL; the more complex part of the requirement is handled with a custom UDF/UDAF.

1. Query all click records, join with the city_info table to get each city's region, and join with product_info to get the product name.
2. Group by region and product, and count the total clicks of each product in each region.
3. Within each region, sort by click count in descending order.
4. Keep only the top three and save the result to the database.
The city-remark column requires a custom UDAF; a condensed sketch of steps 1-4 follows this list.
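
As a condensed sketch of steps 1-4 (assuming a SparkSession named spark with Hive support and the three Hive tables created in the next section), the layered queries can also be written as one nested statement; the UDAF-based remark column is added in the full implementation further below:

// Condensed sketch: join, aggregate per (area, product), rank within each area, keep the top 3
spark.sql(
  """
    |select area, product_name, count
    |from (
    |    select
    |        area,
    |        product_name,
    |        count,
    |        rank() over(partition by area order by count desc) rk
    |    from (
    |        select ci.area, pi.product_name, count(*) count
    |        from user_visit_action uva
    |        join product_info pi on uva.click_product_id = pi.product_id
    |        join city_info ci on uva.city_id = ci.city_id
    |        group by ci.area, pi.product_name
    |    ) t2
    |) t3
    |where rk <= 3
    |""".stripMargin).show()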



II. Hands-on Implementation
1. Prepare the data
  All of the data used in this Spark SQL exercise comes from Hive.

  First create the tables in Hive and load the data.

  There are three tables in total: one user-behavior table, one city table, and one product table.

1. Open Hive




2. Create the three tables

 
CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';
 
CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';
 
CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';





3. Load the data

 
load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;





4. Verify the load

 
hive> select * from city_info;





2. Show the Top 3 popular products in each region

 
-- Tables: user_visit_action, product_info, city_info

-- 1. Select the fields we need  -> t1
select
    ci.*,
    pi.product_name,
    uva.click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id

-- 2. Aggregate by area and product name  -> t2
select
    area,
    product_name,
    count(*) count
from t1
group by area, product_name

-- 3. Partition by area and rank with a window function  -> t3
--    (rank: 1 2 2 4 5..., row_number: 1 2 3 4..., dense_rank: 1 2 2 3 4...)
--    A runnable comparison of these three functions follows this block.
select
    area,
    product_name,
    count,
    rank() over(partition by area order by count desc) rk
from t2

-- 4. Keep only rows ranked 3 or better
select
    area,
    product_name,
    count
from t3
where rk <= 3
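
To make the difference between the three window functions concrete, here is a small self-contained sketch (it assumes an existing SparkSession named spark and uses made-up toy data, not the project tables):

import spark.implicits._

// Toy data: one area with a tie on count
Seq(("East", "p1", 30), ("East", "p2", 20), ("East", "p3", 20), ("East", "p4", 10))
  .toDF("area", "product_name", "count")
  .createOrReplaceTempView("demo")

spark.sql(
  """
    |select
    |    area,
    |    product_name,
    |    count,
    |    rank()       over(partition by area order by count desc) rk,   -- 1 2 2 4
    |    dense_rank() over(partition by area order by count desc) drk,  -- 1 2 2 3
    |    row_number() over(partition by area order by count desc) rn    -- 1 2 3 4
    |from demo
    |""".stripMargin).show()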



Run result




3. Define a UDAF to produce the required remark column

 
package com.buwenbuhuo.spark.sql.project
 
import java.text.DecimalFormat
 
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
 
 
/**
 * @author 不温卜火
 * @create 2020-08-06 13:24
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // Input type: one city name per clicked row, e.g. "北京" (String)
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // Buffer type: for each (area, product) group, a Map of city -> click count
  // (e.g. 北京 -> 1000, 天津 -> 5000) plus the total click count
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // Output type: a String such as "北京:21.20%,天津:13.20%,其他:65.60%"
  override def dataType: DataType = StringType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initialize the city -> count map
    buffer(0) = Map[String, Long]()
    // Initialize the total click count
    buffer(1) = 0L
  }

  // Within-partition update: accumulate Map[city name, click count]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. Increment the total click count
        buffer(1) = buffer.getLong(1) + 1L
        // 2. Increment this city's count: read the buffered map, add 1 to the
        //    city's existing count, then write the map back
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. Merge the totals
    buffer1(1) = total1 + total2
    // 2. Merge the maps
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }
  }

  // Produce the final output
  override def evaluate(buffer: Row): Any = {
    // e.g. "北京:21.20%,天津:13.20%,其他:65.60%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    // Take the top two cities by click count
    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
    // Whatever share is left after the top two is attributed to "其他" (Other)
    cityRemarks :+= CityRemark("其他", cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}
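
As a quick sanity check of the remark format (a minimal sketch; it only exercises the CityRemark case class defined above):

// DecimalFormat("0.00%") scales by 100 and appends '%', so 0.212 renders as "21.20%"
val sample = List(CityRemark("北京", 0.212), CityRemark("天津", 0.132), CityRemark("其他", 0.656))
println(sample.mkString(","))   // 北京:21.20%,天津:13.20%,其他:65.60%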



Run result




4. Save to MySQL

1. Source code

 
val props: Properties = new Properties()
props.put("user","root")
props.put("password","199712")
spark.sql(
  """
    |select
    |    area,
    |    product_name,
    |    count,
    |    remark
    |from t3
    |where rk<=3
    |""".stripMargin)
  .coalesce(1)
  .write
  .mode("overwrite")
  .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
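
For the jdbc call to work, a MySQL JDBC driver has to be on the classpath. If the project is built with sbt, a dependency along these lines would do (the exact coordinates and version shown are an assumption; match them to your MySQL server):

// build.sbt (version shown is only an example)
libraryDependencies += "mysql" % "mysql-connector-java" % "5.1.27"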



2. Run result





III. Complete Code

1. UDAF

package com.buwenbuhuo.spark.sql.project
 
import java.text.DecimalFormat
 
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
 
 
/**
 * @author 不温卜火
 * @create 2020-08-06 13:24
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // Input type: one city name per clicked row, e.g. "北京" (String)
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // Buffer type: for each (area, product) group, a Map of city -> click count
  // (e.g. 北京 -> 1000, 天津 -> 5000) plus the total click count
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // Output type: a String such as "北京:21.20%,天津:13.20%,其他:65.60%"
  override def dataType: DataType = StringType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initialize the city -> count map
    buffer(0) = Map[String, Long]()
    // Initialize the total click count
    buffer(1) = 0L
  }

  // Within-partition update: accumulate Map[city name, click count]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. Increment the total click count
        buffer(1) = buffer.getLong(1) + 1L
        // 2. Increment this city's count: read the buffered map, add 1 to the
        //    city's existing count, then write the map back
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. Merge the totals
    buffer1(1) = total1 + total2
    // 2. Merge the maps
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }
  }

  // Produce the final output
  override def evaluate(buffer: Row): Any = {
    // e.g. "北京:21.20%,天津:13.20%,其他:65.60%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    // Take the top two cities by click count
    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
    // Whatever share is left after the top two is attributed to "其他" (Other)
    cityRemarks :+= CityRemark("其他", cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}



2. Main program (full implementation)

 

 
package com.buwenbuhuo.spark.sql.project
 
import java.util.Properties
 
import org.apache.spark.sql.SparkSession
 
/**
 * @author 不温卜火
 * @create 2020-08-05 19:01
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    spark.udf.register("remark",new CityRemarkUDAF)
 
    // Run the SQL step by step, querying the data from Hive
    spark.sql("use spark0806")
    spark.sql(
      """
        |select
        |    ci.*,
        |    pi.product_name,
        |    uva.click_product_id
        |from user_visit_action uva
        |join product_info pi on uva.click_product_id=pi.product_id
        |join city_info ci on uva.city_id=ci.city_id
        |
        |""".stripMargin).createOrReplaceTempView("t1")
 
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count(*) count,
        |    remark(city_name) remark
        |from t1
        |group by area, product_name
        |""".stripMargin).createOrReplaceTempView("t2")
 
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark,
        |    rank() over(partition by area order by count desc) rk
        |from t2
        |""".stripMargin).createOrReplaceTempView("t3")
 
    // Write the result to MySQL
    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)

    spark.close()
  }
}


Reprinted from blog.csdn.net/ytp552200ytp/article/details/108076903