数据源
链接:https://pan.baidu.com/s/1lUbGmA10yOgUL4Rz2KAGmw
提取码:yh57
源码在github:https://github.com/lidonglin-bit/Spark-Sql
一.数据准备
我们这次 Spark-sql 操作中所有的数据均来自 Hive.
首先在 Hive 中创建表, 并导入数据.
一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表
-- Create the table explicitly in spark1602, the same database the LOAD statement
-- below targets, so the script works regardless of the session's current database.
-- Rows are stored as tab-delimited text.
CREATE TABLE spark1602.`user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';
-- Load the local data file into the table created above.
load data local inpath '/export/servers/datas/user_visit_action.txt' into table spark1602.user_visit_action;
-- Create the table explicitly in spark1602, the same database the LOAD statement
-- below targets, so the script works regardless of the session's current database.
-- Rows are stored as tab-delimited text.
CREATE TABLE spark1602.`product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';
-- Load the local data file into the table created above.
load data local inpath '/export/servers/datas/product_info.txt' into table spark1602.product_info;
-- Create the table explicitly in spark1602, the same database the LOAD statement
-- below targets, so the script works regardless of the session's current database.
-- Rows are stored as tab-delimited text.
CREATE TABLE spark1602.`city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';
-- Load the local data file into the table created above.
load data local inpath '/export/servers/datas/city_info.txt' into table spark1602.city_info;
二.各区域热门商品 Top3
需求简介
这里的热门商品是从点击量的维度来看的.
计算各个区域点击量前三的热门商品,并备注每个商品在主要城市中的分布比例:只列出点击量最高的两个城市,其余城市合并为“其他”显示。
例如:
地区 商品名称 点击次数 城市备注
华北 商品A 100000 北京21.2%,天津13.2%,其他65.6%
华北 商品P 80200 北京63.0%,太原10.0%,其他27.0%
华北 商品M 40000 北京63.0%,太原10.0%,其他27.0%
东北 商品J 92000 大连28.0%,辽宁17.0%,其他55.0%
思路分析
使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf
1.先把需要的字段查出来
2.按照地区和商品名称聚合
3.按照地区分区开窗,并按点击次数降序排名(使用开窗函数 rank)
4.过滤出来名次小于等于3的
5. 城市备注需要自定义 UDAF 函数
具体实现
提前准备
- 1.添加依赖
<!-- Libraries used by the examples in this article. All Spark artifacts are pinned
     to the same version (2.1.1) and use the _2.11 artifact suffix. -->
<dependencies>
<!-- Spark core -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<!-- Spark SQL (SparkSession, spark.sql, UDAF API) -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<!-- Hive integration for Spark SQL; the applications below call enableHiveSupport() -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<!-- MySQL JDBC driver; the last example writes its result through a jdbc:mysql URL -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.27</version>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Build plugin: without it the Scala classes are not compiled and packaged -->
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.4.6</version>
<executions>
<execution>
<goals>
<!-- Run the plugin's compile goal for main sources and testCompile for test sources -->
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
- 2.本次需要用hive,把hive-site.xml文件导入到resources下
测试数据(实现一小部分sql)
实现前面的一部分,后部分要用UDAF
import org.apache.spark.sql.SparkSession
/**
 * First version of the "top 3 hot products per area" job: computes the click count
 * of every product per area from Hive and prints the rows ranked 3 or better.
 * The city remark column is not produced here.
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    // Set before the session is built; presumably the user Hadoop/Hive access runs as -- confirm for your cluster.
    System.setProperty("HADOOP_USER_NAME", "root")
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    // try/finally guarantees the session is closed even when one of the queries fails.
    try {
      // Run the SQL against the Hive database that holds the three source tables.
      spark.sql("use spark1602")

      // t1: one row per visit-action row whose click_product_id matches a product,
      // enriched with the product name and the city/area columns.
      spark.sql(
        """
          |select
          | ci.*,
          | pi.product_name,
          | uva.click_product_id
          |from user_visit_action uva
          |join product_info pi
          |on uva.click_product_id = pi.product_id
          |join city_info ci
          |on uva.city_id = ci.city_id
          |""".stripMargin).createOrReplaceTempView("t1")

      // t2: number of t1 rows per (area, product).
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count(*) count
          |from t1
          |group by area,product_name
          |""".stripMargin).createOrReplaceTempView("t2")

      // t3: rank the products inside each area by count, highest first.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count,
          | rank() over(partition by area order by count desc) rk
          |from t2
          |""".stripMargin).createOrReplaceTempView("t3")

      // Keep rank <= 3; rank() gives tied products the same rank, so an area can yield more than three rows.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count
          |from t3
          |where rk<=3
          |""".stripMargin).show()
    } finally {
      spark.close()
    }
  }
}
结果
使用UDAF实现城市备注的部分
- 1.注册并使用UDAF(UDAF 的具体实现见后面的“实现UDAF”)
import org.apache.spark.sql.SparkSession
/**
 * Second version of the "top 3 hot products per area" job: same pipeline as before,
 * plus a `remark` column produced by the `CityRemarkUDAF` aggregate, which summarises
 * how the count of each (area, product) is distributed over cities.
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    // Set before the session is built; presumably the user Hadoop/Hive access runs as -- confirm for your cluster.
    System.setProperty("HADOOP_USER_NAME", "root")
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    // try/finally guarantees the session is closed even when registration or a query fails.
    try {
      // Make the aggregate callable from SQL under the name "remark".
      spark.udf.register("remark", new CityRemarkUDAF)

      // Run the SQL against the Hive database that holds the three source tables.
      spark.sql("use spark1602")

      // t1: one row per visit-action row whose click_product_id matches a product,
      // enriched with the product name and the city/area columns.
      spark.sql(
        """
          |select
          | ci.*,
          | pi.product_name,
          | uva.click_product_id
          |from user_visit_action uva
          |join product_info pi
          |on uva.click_product_id = pi.product_id
          |join city_info ci
          |on uva.city_id = ci.city_id
          |""".stripMargin).createOrReplaceTempView("t1")

      // t2: row count and city remark per (area, product).
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count(*) count,
          | remark(city_name) remark
          |from t1
          |group by area,product_name
          |""".stripMargin).createOrReplaceTempView("t2")

      // t3: rank the products inside each area by count, highest first.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count,
          | remark,
          | rank() over(partition by area order by count desc) rk
          |from t2
          |""".stripMargin).createOrReplaceTempView("t3")

      // Keep rank <= 3; `false` turns off column truncation so the remark is printed in full.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count,
          | remark
          |from t3
          |where rk<=3
          |""".stripMargin).show(false)
    } finally {
      spark.close()
    }
  }
}
实现UDAF
import java.text.DecimalFormat
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, MapType, StringType, StructField, StructType}
/**
 * Aggregate function that turns the city names of one group into a remark string
 * listing the two most frequent cities with their share of the group's rows, e.g.
 * "北京:21.20%,天津:13.20%,其他:65.60%".
 *
 * The "其他" (other) entry is emitted only when the group contains more than two
 * distinct cities; a group without any non-null city yields an empty string.
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {

  /** Input: a single string column holding a city name. */
  override def inputSchema: StructType = StructType(Array(StructField("city", StringType)))

  /**
   * Aggregation buffer:
   *  - field 0 "map": city name -> number of rows seen for that city
   *  - field 1 "total": number of rows seen over all cities
   */
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("map", MapType(StringType, LongType)),
      StructField("total", LongType)))

  /** The result is the formatted remark string. */
  override def dataType: DataType = StringType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Starts every buffer with an empty city map and a total of zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
    buffer(1) = 0L
  }

  /** Folds one input row into the buffer; rows whose city is not a non-null string are ignored. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // One more row overall ...
        buffer(1) = buffer.getLong(1) + 1L
        // ... and one more row for this city.
        val counts = buffer.getMap[String, Long](0)
        buffer(0) = counts + (cityName -> (counts.getOrElse(cityName, 0L) + 1L))
      case _ => // null city name: contributes neither to the map nor to the total
    }
  }

  /** Merges the partial result `buffer2` into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getMap[String, Long](0)
    val map2 = buffer2.getMap[String, Long](0)

    // Totals simply add up.
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)

    // Per-city counts: add every entry of map1 onto map2.
    buffer1(0) = map1.foldLeft(map2) {
      case (merged, (cityName, count)) =>
        merged + (cityName -> (merged.getOrElse(cityName, 0L) + count))
    }
  }

  /** Formats the final buffer as "city1:p1,city2:p2[,其他:p3]". */
  override def evaluate(buffer: Row): String = {
    val cityAndCount = buffer.getMap[String, Long](0)
    val total = buffer.getLong(1)

    // Highest count first; ties are broken by city name so the output does not
    // depend on the iteration order of the map.
    val ranked = cityAndCount.toList.sortBy { case (cityName, count) => (-count, cityName) }
    val (top2, rest) = ranked.splitAt(2)

    val topRemarks = top2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }

    // Only groups with more than two cities get an "其他" entry. Its share is computed
    // from the remaining counts, which avoids the rounding error of subtracting ratios from 1.
    val otherRemark =
      if (rest.isEmpty) Nil
      else List(CityRemark("其他", rest.map(_._2).sum.toDouble / total))

    (topRemarks ++ otherRemark).mkString(",")
  }
}
/**
 * One entry of a city remark: a city name together with its share of the total.
 *
 * @param cityName  label shown in front of the percentage
 * @param cityRadio share as a fraction (e.g. 0.212); its absolute value is what gets printed
 */
case class CityRemark(cityName: String, cityRadio: Double) {
  // Percentage pattern with exactly two decimal places.
  val f = new DecimalFormat("0.00%")

  /** Renders the entry as "name:percentage". */
  override def toString: String = {
    val percent = f.format(math.abs(cityRadio))
    cityName + ":" + percent
  }
}
结果
把数据写到mysql中
代码实现
import java.util.Properties
import org.apache.spark.sql.SparkSession
/**
 * Final version of the "top 3 hot products per area" job: computes count and city
 * remark per (area, product) from Hive, keeps the rows ranked 3 or better inside each
 * area and writes them to the MySQL table `sql1602` over JDBC.
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    // Set before the session is built; presumably the user Hadoop/Hive access runs as -- confirm for your cluster.
    System.setProperty("HADOOP_USER_NAME", "root")
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    // try/finally guarantees the session is closed even when a query or the JDBC write fails.
    try {
      // Make the aggregate callable from SQL under the name "remark".
      spark.udf.register("remark", new CityRemarkUDAF)

      // Run the SQL against the Hive database that holds the three source tables.
      spark.sql("use spark1602")

      // t1: one row per visit-action row whose click_product_id matches a product,
      // enriched with the product name and the city/area columns.
      spark.sql(
        """
          |select
          | ci.*,
          | pi.product_name,
          | uva.click_product_id
          |from user_visit_action uva
          |join product_info pi
          |on uva.click_product_id = pi.product_id
          |join city_info ci
          |on uva.city_id = ci.city_id
          |""".stripMargin).createOrReplaceTempView("t1")

      // t2: row count and city remark per (area, product).
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count(*) count,
          | remark(city_name) remark
          |from t1
          |group by area,product_name
          |""".stripMargin).createOrReplaceTempView("t2")

      // t3: rank the products inside each area by count, highest first.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count,
          | remark,
          | rank() over(partition by area order by count desc) rk
          |from t2
          |""".stripMargin).createOrReplaceTempView("t3")

      // JDBC connection settings.
      // NOTE(review): host and credentials are hard-coded for the tutorial; move them to configuration for real use.
      val url = "jdbc:mysql://hadoop102:3306/sparksql?useUnicode=true&characterEncoding=UTF-8"
      val user = "root"
      val pw = "root"
      val props = new Properties()
      // setProperty is the String-typed API of java.util.Properties (put accepts any Object).
      props.setProperty("user", user)
      props.setProperty("password", pw)

      // Write the rows ranked 3 or better to MySQL. coalesce(1) reduces the result to a
      // single partition before writing; "overwrite" replaces existing content of the table.
      spark.sql(
        """
          |select
          | area,
          | product_name,
          | count,
          | remark
          |from t3
          |where rk<=3
          |""".stripMargin)
        .coalesce(1)
        .write
        .mode("overwrite")
        .jdbc(url, "sql1602", props)
    } finally {
      spark.close()
    }
  }
}
成功