基于Spark的自定义分区Demo

版权声明:原创作品转载必须标明出处,谢谢配合! https://blog.csdn.net/qq_38704184/article/details/86316972

本demo所需的数据源:

链接: https://pan.baidu.com/s/1VEluh5B3HnodZFyoOZ9Zg 提取码: enmq 

import java.net.URL

import org.apache.spark.{Partitioner, SparkConf, SparkContext}

import scala.collection.mutable

/**
  * Spark demo: count how often each URL occurs in a tab-separated log, re-key the counts by
  * host, repartition with [[HostPartitioner]] (one partition per host) and keep the top-N
  * URLs of every host.
  *
  * Usage: `UrlPartition [inputPath] [outputPath] [topN]` — every argument is optional; the
  * defaults reproduce the original hard-coded demo settings.
  */
object UrlPartition {
  private val DefaultInput = "E:\\sfy.log"
  private val DefaultOutput = "E:\\outfile1"
  private val DefaultTopN = 2

  def main(args: Array[String]): Unit = {
    import scala.util.Try

    val inputPath = args.lift(0).getOrElse(DefaultInput)
    val outputPath = args.lift(1).getOrElse(DefaultOutput)
    val topN = args.lift(2).fold(DefaultTopN)(_.toInt)

    val conf = new SparkConf()
      .setAppName("UrlCountPartition")
      // Default to local mode only; a master given via spark-submit takes precedence.
      .setIfMissing("spark.master", "local[2]")
    val sc = new SparkContext(conf)

    try {
      // (url, 1) for every line; the URL is taken from the second tab-separated field.
      // Lines without a second field are skipped instead of failing the job.
      val urlOnes = sc.textFile(inputPath)
        .map(_.split("\t"))
        .filter(_.length > 1)
        .map(fields => (fields(1), 1))

      // Number of occurrences of each URL.
      val urlCounts = urlOnes.reduceByKey(_ + _)

      // Re-key as (host, (url, count)); URLs that java.net.URL cannot parse are dropped.
      // Cached (lazily, like any transformation) because it is read twice below:
      // once to collect the host list and once for the shuffle.
      val byHost = urlCounts
        .flatMap { case (url, count) =>
          Try(new URL(url).getHost).toOption.map(host => (host, (url, count))).toList
        }
        .cache()

      // Distinct hosts, collected to the driver to build the custom partitioner.
      val hosts = byHost.keys.distinct().collect()
      val hostPartitioner = new HostPartitioner(hosts)

      // partitionBy(partitioner: Partitioner): RDD[(K, V)] — after it, each partition holds the
      // records of a single host, so a per-partition sort yields that host's top-N URLs.
      val topUrls = byHost.partitionBy(hostPartitioner).mapPartitions { records =>
        records.toList.sortBy(_._2._2)(Ordering[Int].reverse).take(topN).iterator
      }
      topUrls.saveAsTextFile(outputPath)
    } finally {
      // Always release the SparkContext, even if a stage fails.
      sc.stop()
    }
  }
}
/* Spark 源码(Partitioner 的定义,供参考):
abstract class Partitioner() extends scala.AnyRef with scala.Serializable {
  def numPartitions : scala.Int
  def getPartition(key : scala.Any) : scala.Int
}
object Partitioner extends scala.AnyRef with scala.Serializable {
  def defaultPartitioner(rdd : org.apache.spark.rdd.RDD[_], others : org.apache.spark.rdd.RDD[_]*) : org.apache.spark.Partitioner = { /* compiled code */ }
}*/
/**
  * Spark partitioner that gives every host its own partition.
  *
  * @param ints host names; the host at index `i` is routed to partition `i`. Callers are
  *             expected to pass distinct hosts (e.g. the result of `distinct().collect()`).
  *             With duplicates the last occurrence wins and the earlier indices stay empty.
  */
class HostPartitioner(ints: Array[String]) extends Partitioner {

  /** Host -> partition index, i.e. the position of each host in `ints`. */
  val partMap: mutable.HashMap[String, Int] =
    mutable.HashMap.empty[String, Int] ++= ints.zipWithIndex

  /** Number of hosts passed in; kept for source compatibility with the former loop counter. */
  var count: Int = ints.length

  override def numPartitions: Int = ints.length

  /**
    * Partition index for `key`, looked up by `key.toString`.
    * Unknown hosts and `null` keys fall back to partition 0.
    * NOTE(review): if `ints` is empty, numPartitions is 0 and no index is valid; this only
    * works when the RDD being partitioned is empty as well.
    */
  override def getPartition(key: Any): Int =
    if (key == null) 0 else partMap.getOrElse(key.toString, 0)

  // Spark compares partitioners with `equals` to decide whether RDDs are already
  // co-partitioned (allowing a shuffle to be skipped), so equality must be value-based.
  override def equals(other: Any): Boolean = other match {
    case that: HostPartitioner =>
      that.numPartitions == numPartitions && that.partMap == partMap
    case _ => false
  }

  // Consistent with equals. partMap is public and mutable; mutating it after the
  // partitioner is in use would change the hash.
  override def hashCode: Int = 31 * numPartitions + partMap.hashCode
}

猜你喜欢

转载自blog.csdn.net/qq_38704184/article/details/86316972