Spark Grouped TopN (SQL style & DSL style), with RDD implementations appended

Approach 1: SQL style

package sql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
  *
  * @ClassName: SparkSQL
  * @Description: SQL style
  * @Author: xuezhouyi
  * @Version: V1.0
  *
  **/
object SparkSQL {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark: SparkSession = SparkSession.builder()
      .appName("XavierXue")
      .master("local[*]")
      .getOrCreate()

    val factRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/fact_tran/FACT_TRAN.txt")
    val prodRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_prod/DIM_PROD.txt")
    val locaRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_location/DIM_LOCATION.txt")

    import spark.implicits._

    factRDD.foreach(println)

    /**
      * 20200101|XA1001|2020-01-01 10:10:01|WY|AliPay|1|999
      * 20200103|XA1001|2020-01-03 10:10:03|YT|AliPay|1|777
      * 20200101|XA1001|2020-01-01 10:10:02|LH|WeChat|2|999
      * 20200201|XA1001|2020-02-01 10:10:01|WY|AliPay|3|666
      * 20200101|XA1001|2020-01-01 10:10:03|YT|AliPay|3|999
      * 20200201|XA1001|2020-02-01 10:10:02|LH|AliPay|2|666
      * 20200102|XA1001|2020-01-02 10:10:01|WY|WeChat|3|888
      * 20200201|XA1001|2020-02-01 10:10:03|YT|WeChat|1|666
      * 20200102|XA1001|2020-01-02 10:10:02|LH|AliPay|3|888
      * 20200301|XA1001|2020-03-01 10:10:01|WY|AliPay|3|888
      * 20200102|XA1001|2020-01-02 10:10:03|YT|AliPay|2|888
      * 20200301|XA1001|2020-03-01 10:10:02|LH|AliPay|1|888
      * 20200301|XA1001|2020-03-01 10:10:03|YT|WeChat|1|888
      * 20200103|XA1001|2020-01-03 10:10:01|WY|AliPay|3|777
      * 20200103|XA1001|2020-01-03 10:10:02|LH|WeChat|3|777
      */
    prodRDD.foreach(println)

    /**
      * 1|家电
      * 2|数码
      * 3|手机
      */
    locaRDD.foreach(println)
    /**
      * WY|未央区|XA|西安市
      * YT|雁塔区|XA|西安市
      * LH|莲湖区|XA|西安市
      */

    val factDF: DataFrame = factRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1), strings(2), strings(3), strings(4), strings(5), strings(6).trim.toInt)
    }).toDF("DATA_DATE", "ACCT", "TRAN_TIME", "LOCATION", "CHANNEL", "PROD_TYPE", "TRAN_AMT")
    factDF.show()
    factDF.createOrReplaceTempView("fact")
    /**
      * +---------+------+-------------------+--------+-------+---------+--------+
      * |DATA_DATE|  ACCT|          TRAN_TIME|LOCATION|CHANNEL|PROD_TYPE|TRAN_AMT|
      * +---------+------+-------------------+--------+-------+---------+--------+
      * | 20200101|XA1001|2020-01-01 10:10:01|      WY| AliPay|        1|     999|
      * | 20200101|XA1001|2020-01-01 10:10:02|      LH| WeChat|        2|     999|
      * | 20200101|XA1001|2020-01-01 10:10:03|      YT| AliPay|        3|     999|
      * | 20200102|XA1001|2020-01-02 10:10:01|      WY| WeChat|        3|     888|
      * | 20200102|XA1001|2020-01-02 10:10:02|      LH| AliPay|        3|     888|
      * | 20200102|XA1001|2020-01-02 10:10:03|      YT| AliPay|        2|     888|
      * | 20200103|XA1001|2020-01-03 10:10:01|      WY| AliPay|        3|     777|
      * | 20200103|XA1001|2020-01-03 10:10:02|      LH| WeChat|        3|     777|
      * | 20200103|XA1001|2020-01-03 10:10:03|      YT| AliPay|        1|     777|
      * | 20200201|XA1001|2020-02-01 10:10:01|      WY| AliPay|        3|     666|
      * | 20200201|XA1001|2020-02-01 10:10:02|      LH| AliPay|        2|     666|
      * | 20200201|XA1001|2020-02-01 10:10:03|      YT| WeChat|        1|     666|
      * | 20200301|XA1001|2020-03-01 10:10:01|      WY| AliPay|        3|     888|
      * | 20200301|XA1001|2020-03-01 10:10:02|      LH| AliPay|        1|     888|
      * | 20200301|XA1001|2020-03-01 10:10:03|      YT| WeChat|        1|     888|
      * +---------+------+-------------------+--------+-------+---------+--------+
      */

    val prodDF: DataFrame = prodRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).toDF("ID", "NAME")
    prodDF.show()
    prodDF.createOrReplaceTempView("prod")
    /**
      * +---+----+
      * | ID|NAME|
      * +---+----+
      * |  1|  家电|
      * |  2|  数码|
      * |  3|  手机|
      * +---+----+
      */

    val locaDF: DataFrame = locaRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1), strings(2), strings(3))
    }).toDF("ID1", "NAME1", "ID2", "NAME2")
    locaDF.show()
    locaDF.createOrReplaceTempView("loca")
    /**
      * +---+-----+---+-----+
      * |ID1|NAME1|ID2|NAME2|
      * +---+-----+---+-----+
      * | WY|  未央区| XA|  西安市|
      * | LH|  莲湖区| XA|  西安市|
      * | YT|  雁塔区| XA|  西安市|
      * +---+-----+---+-----+
      */

    var sql: String =
      """
        |select /*+ broadcast(p,l) */
        |l.NAME1 as LOCA,p.NAME as PROD,sum(f.TRAN_AMT) as TOTAL
        |from FACT f
        |inner join PROD p on f.PROD_TYPE=p.ID
        |inner join LOCA l on f.LOCATION=l.ID1
        |group by l.NAME1,p.NAME
      """.stripMargin
    spark.sql(sql).show()

    /**
      * +----+----+-----+
      * |LOCA|PROD|TOTAL|
      * +----+----+-----+
      * | 莲湖区|  数码| 1665|
      * | 莲湖区|  家电|  888|
      * | 莲湖区|  手机| 1665|
      * | 未央区|  家电|  999|
      * | 未央区|  手机| 3219|
      * | 雁塔区|  数码|  888|
      * | 雁塔区|  手机|  999|
      * | 雁塔区|  家电| 2331|
      * +----+----+-----+
      */
    spark.sql(sql).createOrReplaceTempView("tmp")

    sql =
      """
        |with t as (
        |select t.*,row_number() over(partition by t.LOCA order by t.TOTAL desc) as RN
        |from tmp t)
        |select t.LOCA,t.PROD,t.TOTAL from t where t.RN=1
      """.stripMargin
    spark.sql(sql).show()

    /**
      * +----+----+-----+
      * |LOCA|PROD|TOTAL|
      * +----+----+-----+
      * | 莲湖区|  数码| 1665|
      * | 未央区|  手机| 3219|
      * | 雁塔区|  家电| 2331|
      * +----+----+-----+
      */

    spark.stop()
  }
}
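
A side note on the two-step SQL above: the aggregation and the window function can be chained as CTEs in a single statement, which makes the intermediate tmp view unnecessary. A minimal sketch against the same registered views (topnSql is an illustrative name, not from the original code):

val topnSql: String =
  """
    |with agg as (
    |select /*+ broadcast(p,l) */
    |l.NAME1 as LOCA,p.NAME as PROD,sum(f.TRAN_AMT) as TOTAL
    |from fact f
    |inner join prod p on f.PROD_TYPE=p.ID
    |inner join loca l on f.LOCATION=l.ID1
    |group by l.NAME1,p.NAME),
    |ranked as (
    |select agg.*,row_number() over(partition by LOCA order by TOTAL desc) as RN
    |from agg)
    |select LOCA,PROD,TOTAL from ranked where RN=1
  """.stripMargin
spark.sql(topnSql).show()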

Approach 2: DSL style

package sql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{Window, WindowSpec}
import org.apache.spark.sql.{DataFrame, SparkSession, functions}

/**
  *
  * @ClassName: SparkDF
  * @Description: DSL style
  * @Author: xuezhouyi
  * @Version: V1.0
  *
  **/
object SparkDF {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark: SparkSession = SparkSession.builder()
      .appName("XavierXue")
      .master("local[*]")
      .getOrCreate()

    val factRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/fact_tran/FACT_TRAN.txt")
    val prodRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_prod/DIM_PROD.txt")
    val locaRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_location/DIM_LOCATION.txt")

    import spark.implicits._

    factRDD.foreach(println)

    /**
      * 20200101|XA1001|2020-01-01 10:10:01|WY|AliPay|1|999
      * 20200103|XA1001|2020-01-03 10:10:03|YT|AliPay|1|777
      * 20200101|XA1001|2020-01-01 10:10:02|LH|WeChat|2|999
      * 20200201|XA1001|2020-02-01 10:10:01|WY|AliPay|3|666
      * 20200101|XA1001|2020-01-01 10:10:03|YT|AliPay|3|999
      * 20200201|XA1001|2020-02-01 10:10:02|LH|AliPay|2|666
      * 20200102|XA1001|2020-01-02 10:10:01|WY|WeChat|3|888
      * 20200201|XA1001|2020-02-01 10:10:03|YT|WeChat|1|666
      * 20200102|XA1001|2020-01-02 10:10:02|LH|AliPay|3|888
      * 20200301|XA1001|2020-03-01 10:10:01|WY|AliPay|3|888
      * 20200102|XA1001|2020-01-02 10:10:03|YT|AliPay|2|888
      * 20200301|XA1001|2020-03-01 10:10:02|LH|AliPay|1|888
      * 20200301|XA1001|2020-03-01 10:10:03|YT|WeChat|1|888
      * 20200103|XA1001|2020-01-03 10:10:01|WY|AliPay|3|777
      * 20200103|XA1001|2020-01-03 10:10:02|LH|WeChat|3|777
      */
    prodRDD.foreach(println)

    /**
      * 1|家电
      * 2|数码
      * 3|手机
      */
    locaRDD.foreach(println)
    /**
      * WY|未央区|XA|西安市
      * YT|雁塔区|XA|西安市
      * LH|莲湖区|XA|西安市
      */

    val factDF: DataFrame = factRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1), strings(2), strings(3), strings(4), strings(5), strings(6).trim.toInt)
    }).toDF("DATA_DATE", "ACCT", "TRAN_TIME", "LOCATION", "CHANNEL", "PROD_TYPE", "TRAN_AMT")
    factDF.show()
    /**
      * +---------+------+-------------------+--------+-------+---------+--------+
      * |DATA_DATE|  ACCT|          TRAN_TIME|LOCATION|CHANNEL|PROD_TYPE|TRAN_AMT|
      * +---------+------+-------------------+--------+-------+---------+--------+
      * | 20200101|XA1001|2020-01-01 10:10:01|      WY| AliPay|        1|     999|
      * | 20200101|XA1001|2020-01-01 10:10:02|      LH| WeChat|        2|     999|
      * | 20200101|XA1001|2020-01-01 10:10:03|      YT| AliPay|        3|     999|
      * | 20200102|XA1001|2020-01-02 10:10:01|      WY| WeChat|        3|     888|
      * | 20200102|XA1001|2020-01-02 10:10:02|      LH| AliPay|        3|     888|
      * | 20200102|XA1001|2020-01-02 10:10:03|      YT| AliPay|        2|     888|
      * | 20200103|XA1001|2020-01-03 10:10:01|      WY| AliPay|        3|     777|
      * | 20200103|XA1001|2020-01-03 10:10:02|      LH| WeChat|        3|     777|
      * | 20200103|XA1001|2020-01-03 10:10:03|      YT| AliPay|        1|     777|
      * | 20200201|XA1001|2020-02-01 10:10:01|      WY| AliPay|        3|     666|
      * | 20200201|XA1001|2020-02-01 10:10:02|      LH| AliPay|        2|     666|
      * | 20200201|XA1001|2020-02-01 10:10:03|      YT| WeChat|        1|     666|
      * | 20200301|XA1001|2020-03-01 10:10:01|      WY| AliPay|        3|     888|
      * | 20200301|XA1001|2020-03-01 10:10:02|      LH| AliPay|        1|     888|
      * | 20200301|XA1001|2020-03-01 10:10:03|      YT| WeChat|        1|     888|
      * +---------+------+-------------------+--------+-------+---------+--------+
      */

    val prodDF: DataFrame = prodRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).toDF("ID", "NAME")
    prodDF.show()
    /**
      * +---+----+
      * | ID|NAME|
      * +---+----+
      * |  1|  家电|
      * |  2|  数码|
      * |  3|  手机|
      * +---+----+
      */

    val locaDF: DataFrame = locaRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1), strings(2), strings(3))
    }).toDF("ID1", "NAME1", "ID2", "NAME2")
    locaDF.show()

    /**
      * +---+-----+---+-----+
      * |ID1|NAME1|ID2|NAME2|
      * +---+-----+---+-----+
      * | WY|  未央区| XA|  西安市|
      * | LH|  莲湖区| XA|  西安市|
      * | YT|  雁塔区| XA|  西安市|
      * +---+-----+---+-----+
      */

    val result1: DataFrame = factDF.join(prodDF, factDF("PROD_TYPE") === prodDF("ID"), "inner")
      .join(locaDF, factDF("LOCATION") === locaDF("ID1"), "inner")
      .groupBy("NAME1", "NAME")
      .agg(("TRAN_AMT", "sum"))
      .toDF("LOCA", "PROD", "TOTAL")
    result1.show()

    /**
      * +----+----+-----+
      * |LOCA|PROD|TOTAL|
      * +----+----+-----+
      * | 莲湖区|  数码| 1665|
      * | 未央区|  家电|  999|
      * | 未央区|  手机| 3219|
      * | 雁塔区|  数码|  888|
      * | 雁塔区|  手机|  999|
      * | 莲湖区|  手机| 1665|
      * | 雁塔区|  家电| 2331|
      * | 莲湖区|  家电|  888|
      * +----+----+-----+
      */

    val win: WindowSpec = Window.partitionBy($"LOCA").orderBy($"TOTAL".desc)
    val result: DataFrame = result1.withColumn("RN", functions.row_number().over(win))
      .where($"RN" === 1).drop("RN")
    result.show()

    /**
      * +----+----+-----+
      * |LOCA|PROD|TOTAL|
      * +----+----+-----+
      * | 莲湖区|  数码| 1665|
      * | 未央区|  手机| 3219|
      * | 雁塔区|  家电| 2331|
      * +----+----+-----+
      */

    spark.stop()
  }
}
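
A note on the aggregation above: .agg(("TRAN_AMT", "sum")) produces a column literally named sum(TRAN_AMT), which the positional .toDF("LOCA", "PROD", "TOTAL") rename then papers over. A sketch of an equivalent, more explicit variant (result1Alt is an illustrative name):

// Same joins and grouping, but the aggregate column is aliased inline
// instead of being renamed positionally afterwards.
val result1Alt: DataFrame = factDF
  .join(prodDF, factDF("PROD_TYPE") === prodDF("ID"), "inner")
  .join(locaDF, factDF("LOCATION") === locaDF("ID1"), "inner")
  .groupBy("NAME1", "NAME")
  .agg(functions.sum("TRAN_AMT").as("TOTAL"))
  .withColumnRenamed("NAME1", "LOCA")
  .withColumnRenamed("NAME", "PROD")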

Approach 3: RDD

package sql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.Partitioner
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

/**
  *
  * @ClassName: SparkRDD
  * @Description: Grouped TopN implemented with RDDs
  * @Author: xuezhouyi
  * @Version: V1.0
  *
  **/
object SparkRDD {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark: SparkSession = SparkSession.builder()
      .appName("XavierXue")
      .master("local[*]")
      .getOrCreate()

    val factRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/fact_tran/FACT_TRAN.txt")
    val prodRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_prod/DIM_PROD.txt")
    val locaRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_location/DIM_LOCATION.txt")

    factRDD.foreach(println)

    /**
      * 20200101|XA1001|2020-01-01 10:10:01|WY|AliPay|1|999
      * 20200103|XA1001|2020-01-03 10:10:03|YT|AliPay|1|777
      * 20200101|XA1001|2020-01-01 10:10:02|LH|WeChat|2|999
      * 20200201|XA1001|2020-02-01 10:10:01|WY|AliPay|3|666
      * 20200101|XA1001|2020-01-01 10:10:03|YT|AliPay|3|999
      * 20200201|XA1001|2020-02-01 10:10:02|LH|AliPay|2|666
      * 20200102|XA1001|2020-01-02 10:10:01|WY|WeChat|3|888
      * 20200201|XA1001|2020-02-01 10:10:03|YT|WeChat|1|666
      * 20200102|XA1001|2020-01-02 10:10:02|LH|AliPay|3|888
      * 20200301|XA1001|2020-03-01 10:10:01|WY|AliPay|3|888
      * 20200102|XA1001|2020-01-02 10:10:03|YT|AliPay|2|888
      * 20200301|XA1001|2020-03-01 10:10:02|LH|AliPay|1|888
      * 20200301|XA1001|2020-03-01 10:10:03|YT|WeChat|1|888
      * 20200103|XA1001|2020-01-03 10:10:01|WY|AliPay|3|777
      * 20200103|XA1001|2020-01-03 10:10:02|LH|WeChat|3|777
      */

    prodRDD.foreach(println)
    /**
      * 1|家电
      * 2|数码
      * 3|手机
      */

    val prodMap: Map[String, String] = prodRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).collect().toMap

    /**
      * Broadcast the product dimension
      */
    val prodBC: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(prodMap)

    locaRDD.foreach(println)
    /**
      * WY|未央区|XA|西安市
      * YT|雁塔区|XA|西安市
      * LH|莲湖区|XA|西安市
      */

    val locaMap: Map[String, String] = locaRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).collect().toMap

    /**
      * Broadcast the location dimension
      */
    val locaBC: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(locaMap)

    val value: RDD[((String, String), Int)] = factRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      fact(strings(3), strings(5), strings(6).trim.toInt)
    }).map(e => {
      val locaStr: String = locaBC.value.getOrElse(e.loca, "Nil")
      val prodStr: String = prodBC.value.getOrElse(e.prod, "Nil")
      ((locaStr, prodStr), e.amt)
    }).reduceByKey(_ + _)

    value.foreach(println)

    /**
      * ((WY,1),999)
      * ((YT,1),2331)
      * ((YT,2),888)
      * ((YT,3),999)
      * ((LH,2),1665)
      * ((LH,3),1665)
      * ((WY,3),3219)
      * ((LH,1),888)
      * ---------- translated via the broadcast maps ----------
      * ((莲湖区,家电),888)
      * ((莲湖区,手机),1665)
      * ((莲湖区,数码),1665)
      * ((雁塔区,手机),999)
      * ((未央区,家电),999)
      * ((未央区,手机),3219)
      * ((雁塔区,数码),888)
      * ((雁塔区,家电),2331)
      */

    /**
      * Option 1: group by district, then sort each group's List in memory
      */
    val result1: RDD[(String, List[((String, String), Int)])] = value.groupBy(_._1._1).mapValues(_.toList.sortBy(_._2).reverse.take(1))
    result1.foreach(println)

    /**
      * (莲湖区,List(((莲湖区,数码),1665)))
      * (雁塔区,List(((雁塔区,家电),2331)))
      * (未央区,List(((未央区,手机),3219)))
      */

    /**
      * Option 2: filter the RDD down to one district at a time, then sort and take the top record
      */
    for (e <- locaMap) {
      val filterRDD: RDD[((String, String), Int)] = value.filter(_._1._1 == e._2)
      val result2: Array[((String, String), Int)] = filterRDD.sortBy(_._2, false).take(1)
      println(result2.mkString(","))
    }

    /**
      * ((未央区,手机),3219)
      * ((莲湖区,手机),1665)
      * ((雁塔区,家电),2331)
      */

    /**
      * Option 3: custom partitioner, one partition per district
      */
    val locaPts: Array[String] = value.map(_._1._1).distinct().collect()
    val valuePts: RDD[((String, String), Int)] = value.partitionBy(new MyPartitioner(locaPts))
    val result3: Array[((String, String), Int)] = valuePts.mapPartitions(e => {
      e.toList.sortBy(_._2).reverse.take(1).iterator
    }).collect()
    println(result3.mkString(","))

    /**
      * ((莲湖区,数码),1665),((未央区,手机),3219),((雁塔区,家电),2331)
      */

    spark.stop()
  }
}

/**
  * Case class for a fact-table record
  *
  * @param loca location (district) code
  * @param prod product type id
  * @param amt  transaction amount
  */
case class fact(loca: String, prod: String, amt: Int)

/**
  * Custom partitioner: one partition per distinct key
  *
  * @param value the distinct partition keys (district names)
  */
class MyPartitioner(value: Array[String]) extends Partitioner {
  override def numPartitions: Int = value.length

  // Map each key to a partition index: value(0) -> 0, value(1) -> 1, ...
  private val map: mutable.HashMap[String, Int] = mutable.HashMap(value.zipWithIndex: _*)

  override def getPartition(key: Any): Int = {
    val p: String = key.asInstanceOf[(String, String)]._1
    map(p)
  }
}
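
A caveat on Option 1 above: groupBy followed by toList materializes every record of a group in memory, which is fine for a demo but risky for skewed groups. Since only the top record per district is needed, a reduceByKey can keep a single running best pair instead. A sketch under the names used in the listing (top1 is an illustrative name; value is the aggregated RDD from the listing):

// Re-key by district and keep only the best (PROD, TOTAL) pair while reducing,
// so no group is ever collected into a full in-memory List.
val top1: RDD[(String, (String, Int))] = value
  .map { case ((loca, prod), total) => (loca, (prod, total)) }
  .reduceByKey((a, b) => if (a._2 >= b._2) a else b)
top1.foreach(println)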

Approach 4: RDD (optimized)

package sql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.Partitioner
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

/**
  *
  * @ClassName: SparkRDDOpt
  * @Description: Grouped TopN implemented with RDDs (optimized)
  * @Author: xuezhouyi
  * @Version: V1.0
  *
  **/
object SparkRDDOpt {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark: SparkSession = SparkSession.builder()
      .appName("XavierXue")
      .master("local[*]")
      .getOrCreate()

    val factRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/fact_tran/FACT_TRAN.txt")
    val prodRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_prod/DIM_PROD.txt")
    val locaRDD: RDD[String] = spark.sparkContext.textFile("hdfs://hadoopha/user/hive/warehouse/ods.db/dim_location/DIM_LOCATION.txt")

    factRDD.foreach(println)

    /**
      * 20200101|XA1001|2020-01-01 10:10:01|WY|AliPay|1|999
      * 20200103|XA1001|2020-01-03 10:10:03|YT|AliPay|1|777
      * 20200101|XA1001|2020-01-01 10:10:02|LH|WeChat|2|999
      * 20200201|XA1001|2020-02-01 10:10:01|WY|AliPay|3|666
      * 20200101|XA1001|2020-01-01 10:10:03|YT|AliPay|3|999
      * 20200201|XA1001|2020-02-01 10:10:02|LH|AliPay|2|666
      * 20200102|XA1001|2020-01-02 10:10:01|WY|WeChat|3|888
      * 20200201|XA1001|2020-02-01 10:10:03|YT|WeChat|1|666
      * 20200102|XA1001|2020-01-02 10:10:02|LH|AliPay|3|888
      * 20200301|XA1001|2020-03-01 10:10:01|WY|AliPay|3|888
      * 20200102|XA1001|2020-01-02 10:10:03|YT|AliPay|2|888
      * 20200301|XA1001|2020-03-01 10:10:02|LH|AliPay|1|888
      * 20200301|XA1001|2020-03-01 10:10:03|YT|WeChat|1|888
      * 20200103|XA1001|2020-01-03 10:10:01|WY|AliPay|3|777
      * 20200103|XA1001|2020-01-03 10:10:02|LH|WeChat|3|777
      */

    prodRDD.foreach(println)
    /**
      * 1|家电
      * 2|数码
      * 3|手机
      */

    val prodMap: Map[String, String] = prodRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).collect().toMap

    /**
      * Broadcast the product dimension
      */
    val prodBC: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(prodMap)

    locaRDD.foreach(println)
    /**
      * WY|未央区|XA|西安市
      * YT|雁塔区|XA|西安市
      * LH|莲湖区|XA|西安市
      */

    val locaMap: Map[String, String] = locaRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      (strings(0), strings(1))
    }).collect().toMap

    /**
      * Broadcast the location dimension
      */
    val locaBC: Broadcast[Map[String, String]] = spark.sparkContext.broadcast(locaMap)

    /**
      * Build the keyed RDD, translating codes via the broadcast maps
      */
    val value: RDD[((String, String), Int)] = factRDD.map(e => {
      val strings: Array[String] = e.split("\\|")
      fact(strings(3), strings(5), strings(6).trim.toInt)
    }).map(e => {
      val locaStr: String = locaBC.value(e.loca)
      val prodStr: String = prodBC.value(e.prod)
      ((locaStr, prodStr), e.amt)
    })

    value.foreach(println)

    /**
      * Output: codes already translated via the broadcast maps,
      * amounts not yet aggregated (reduceByKey happens later in this version).
      * ((未央区,家电),999)
      * ((莲湖区,数码),999)
      * ((雁塔区,手机),999)
      * ((未央区,手机),888)
      * ((莲湖区,手机),888)
      * ((雁塔区,数码),888)
      * ((未央区,手机),777)
      * ((莲湖区,手机),777)
      * ((雁塔区,家电),777)
      * ((未央区,手机),666)
      * ((莲湖区,数码),666)
      * ((雁塔区,家电),666)
      * ((未央区,手机),888)
      * ((莲湖区,家电),888)
      * ((雁塔区,家电),888)
      */

    val keyPts: Array[String] = value.map(_._1._1).distinct().collect()
    val result: Array[((String, String), Int)] = value.reduceByKey(new MyPartitioner(keyPts), _ + _)
      .mapPartitions(e => e.toList.sortBy(_._2).reverse.take(1).iterator)
      .collect()
    println(result.mkString(","))

    /**
      * ((莲湖区,手机),1665),((未央区,手机),3219),((雁塔区,家电),2331)
      */

    spark.stop()
  }
}

/**
  * Case class for a fact-table record
  *
  * @param loca location (district) code
  * @param prod product type id
  * @param amt  transaction amount
  */
case class fact(loca: String, prod: String, amt: Int)

/**
  * Custom partitioner: one partition per distinct key
  *
  * @param value the distinct partition keys (district names)
  */
class MyPartitioner(value: Array[String]) extends Partitioner {
  override def numPartitions: Int = value.length

  // Map each key to a partition index: value(0) -> 0, value(1) -> 1, ...
  private val map: mutable.HashMap[String, Int] = mutable.HashMap(value.zipWithIndex: _*)

  override def getPartition(key: Any): Int = {
    val p: String = key.asInstanceOf[(String, String)]._1
    map(p)
  }
}
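
Two closing remarks on this version. First, the actual optimization over Approach 3: handing MyPartitioner directly to reduceByKey shuffles the data once, straight into one partition per district, instead of reducing under the default partitioner and then repartitioning in a second shuffle. Second, map(p) in getPartition throws for any key the partitioner was not built with; a defensive variant (a sketch, not in the original post) could fall back to partition 0:

override def getPartition(key: Any): Int = {
  val p: String = key.asInstanceOf[(String, String)]._1
  map.getOrElse(p, 0) // route unknown districts to partition 0 instead of throwing
}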

Reposted from blog.csdn.net/DataIntel_XiAn/article/details/104115023