Spark SQL 项目实战 | 计算各区域热门商品 Top3

  大家好,我是不温卜火,是一名计算机学院大数据专业大二的学生,昵称来源于成语—不温不火,本意是希望自己性情温和。作为一名互联网行业的小白,博主写博客一方面是为了记录自己的学习过程,另一方面是总结自己所犯的错误希望能够帮助到很多和自己一样处于起步阶段的萌新。但由于水平有限,博客中难免会有一些错误出现,有纰漏之处恳请各位大佬不吝赐教!暂时只有csdn这一个平台,博客主页:https://buwenbuhuo.blog.csdn.net/

  本片博文为大家带来的是计算各区域热门商品 Top3。
1


2

一. 需求

1.1 需求简介

这里的热门商品是从点击量的维度来看的.

计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。
8

1.2 思路分析

使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf

  1. 查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区. 与 Product_info 表连接得到产品名称
  2. 按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
  3. 每个地区内按照点击次数降序排列
  4. 只取前三名. 并把结果保存在数据库中
  5. 城市备注需要自定义 UDAF 函数

二. 实际操作

1. 准备数据

  我们这次 Spark-sql 操作中所有的数据均来自 Hive.

  首先在 Hive 中创建表, 并导入数据.

  一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表

  • 1. 打开Hive

3

  • 2. 创建三个表
CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';

CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';

CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';

6

  • 3. 上传数据
load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;

7

  • 4. 测试是否上传成功
hive> select * from city_info;

6

2. 显示各区域热门商品 Top3

  • 1. 源码

// user_visit_action  product_info  city_info

1. 先把需要的字段查出来   t1
select
    ci.*,
    pi.product_name,
    click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id

2. 按照地区和商品名称聚合
select
    area,
    product_name,
    count(*)  count
from t1
group by area , product_name

3. 按照地区进行分组开窗 排序 开窗函数 t3 // (rank(1 2 2 4 5...) row_number(1 2 3 4...) dense_rank(1 2 2 3 4...))
select
    area,
    product_name,
    count,
    rank() over(partition by area order by count desc)
from  t2


4. 过滤出来名次小于等于3的
select
    area,
    product_name,
    count
from  t3
where rk <=3
  • 2. 运行结果

7

3. 定义udaf函数 得到需求结果

  • 1. 源码
package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-06 13:24
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // 输入数据的类型:  北京  String
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000  Map,  总的点击量  1000/?
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // 输出的数据类型  "北京21.2%,天津13.2%,其他65.6%"  String
  override def dataType: DataType = StringType

  // 相同的输入是否应用有相同的输出.
  override def deterministic: Boolean = true

  // 给存储数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map[String, Long]()
    // 初始化总的点击量
    buffer(1) = 0L
  }

  // 分区内合并 Map[城市名, 点击量]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. 总的点击量 + 1
        buffer(1) = buffer.getLong(1) + 1L
        // 2. 给这个城市的点击量 +1 =>   找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // 分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1 + total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }

  }

  // 最终的输出结果
  override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
    cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}

  • 运行结果

9

4 .保存到Mysql

  • 1. 源码
    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
  • 2.运行结果

10

三. 完整代码

  • 1. udaf
package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-06 13:24
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // 输入数据的类型:  北京  String
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000  Map,  总的点击量  1000/?
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // 输出的数据类型  "北京21.2%,天津13.2%,其他65.6%"  String
  override def dataType: DataType = StringType

  // 相同的输入是否应用有相同的输出.
  override def deterministic: Boolean = true

  // 给存储数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map[String, Long]()
    // 初始化总的点击量
    buffer(1) = 0L
  }

  // 分区内合并 Map[城市名, 点击量]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. 总的点击量 + 1
        buffer(1) = buffer.getLong(1) + 1L
        // 2. 给这个城市的点击量 +1 =>   找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // 分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1 + total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }

  }

  // 最终的输出结果
  override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
    cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}

  • 2. 主程序(具体实现)
package com.buwenbuhuo.spark.sql.project

import java.util.Properties

import org.apache.spark.sql.SparkSession

/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-05 19:01
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    spark.udf.register("remark",new CityRemarkUDAF)

    // 去执行sql,从hive查询数据
    spark.sql("use spark0806")
    spark.sql(
      """
        |select
        |    ci.*,
        |    pi.product_name,
        |    uva.click_product_id
        |from user_visit_action uva
        |join product_info pi on uva.click_product_id=pi.product_id
        |join city_info ci on uva.city_id=ci.city_id
        |
        |""".stripMargin).createOrReplaceTempView("t1")

    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count(*) count,
        |    remark(city_name) remark
        |from t1
        |group by area, product_name
        |""".stripMargin).createOrReplaceTempView("t2")

    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark,
        |    rank() over(partition by area order by count desc) rk
        |from t2
        |""".stripMargin).createOrReplaceTempView("t3")

    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)


    // 把结果写入到mysql中

    spark.close()
  }
}


  本次的分享就到这里了,


14

  好书不厌读百回,熟读课思子自知。而我想要成为全场最靓的仔,就必须坚持通过学习来获取更多知识,用知识改变命运,用博客见证成长,用行动证明我在努力。
  如果我的博客对你有帮助、如果你喜欢我的博客内容,请“点赞” “评论”“收藏”一键三连哦!听说点赞的人运气不会太差,每一天都会元气满满呦!如果实在要白嫖的话,那祝你开心每一天,欢迎常来我博客看看。
  码字不易,大家的支持就是我坚持下去的动力。点赞后不要忘了关注我哦!

15

16

猜你喜欢

转载自blog.csdn.net/qq_16146103/article/details/107824095