大数据学习之路89-sparkSQL自定义函数计算ip归属地

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_37050372/article/details/82959016

使用sparkSQL当遇到业务逻辑相关的时候,就有可能会搞不定。因为业务l逻辑需要写很多代码,调用很多接口。这个时候sql就搞不定了。那么这个时候我们就会想能不能将业务逻辑嵌入到sql中?

这种就类似于我们在hive中使用过的自定义函数UDF(user define function用户自定义函数)

那么用户自定义函数有几种呢

有三种:

第一种就是UDF  1 - 1 (输入一行得到一个结果)

第二种就是UDTF 1 - N(输入一行得到多个结果,这个在spark中没有,在hive中有)

为什么在spark中没有UDTF呢?因为在spark中一个flatMap就可以搞定了。不需要自定义函数。

类似于这样:

第三种就是UDAF N - 1 (输入N行得到1个结果)

接下来我们写一个效率低下的通过sparkSQL自定义函数计算IP归属地

package com.test.SparkSQL

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object SQLIPLocation {
  val ip2Long = (ip:String )=> {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for(i <- 0 until fragments.length){
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("SQLIPLocation")
      .master("local[*]").getOrCreate()
    import spark.implicits._
    //加载ip规则
    val rulelines: Dataset[String] = spark.read.textFile("D:/a/ip.txt")
    //对ip规则进行整理
    val rulesDS: Dataset[(Long, Long, String)] = rulelines.map(line => {
      val fields: Array[String] = line.split("[|]")
      val startNum: Long = fields(2).toLong
      val endNum: Long = fields(3).toLong
      val province: String = fields(6)
      (startNum, endNum, province)
    })
    //将DataSet转为DataFrame
    val rulesDF: DataFrame = rulesDS.toDF("start_num","end_num","province")
    rulesDF.createTempView("v_rules")
    //读取访问日志
    val accessDS: Dataset[String] = spark.read.textFile("D:/a/access.log")
    val ipDS: Dataset[String] = accessDS.map(line => {
      val fields: Array[String] = line.split("[|]")
      val ip = fields(1)
      ip
    })
    val ipDF: DataFrame = ipDS.toDF("ip")
    //注册视图
    ipDF.createTempView("v_ip")
    //使用之前注册自定义函数
    spark.udf.register("ip2Long",ip2Long)
    val result: DataFrame = spark.sql("select province,count(*) from v_ip join v_rules on (ip2Long(ip) >= start_num and ip2Long(ip) <= end_num) group by province")
    result.show()

  }
}

运行结果:

使用这样的代码去执行的话是非常慢的。可能要等好半天,这是为什么呢?

因为两张表的join会产生大量的shuffle

如果是两个大表在进行join的时候,shuffle量也会超大。

那么怎么办呢?

对于这两张表来说,ip规则表相对来说要比较小。

我们怎样才能提高性能呢?

如果是在hive中的话我们怎样做呢?我们会将小表先缓存起来,那么怎样才能将小表缓存起来呢?

我们需要将小表放在前面。就是我们将ip规则直接放在map端,这样我们就不需要进行shuffle了。

我们可以将ip规则广播出去,sparksql也支持广播变量

这样有人就说了,那我们还不如之前通过RDD或DataSet来做呢。

但是我们如果使用DataSet的话还要记忆很多算子。使用spark sql的话就简单了。

------------------------------------------------------------------------------------------------------------

我们对上面的程序进行一下优化:

首先我们读取访问日志,对原来的数据进行数据清洗之后保存成parquet的格式,只保留重要的字段。

package com.test.SparkSQL

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object AccessLog {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("AccessLog")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._
    val accessDS: Dataset[String] = spark.read.textFile("D:/a/access.log")
    val ipDS: Dataset[(String, String)] = accessDS.map(line => {
      val fields: Array[String] = line.split("[|]")
      val time = fields(0)
      val ip = fields(1)
      (time, ip)
    })
    val df: DataFrame = ipDS.toDF("time","ip")
    df.write.parquet("C:/Users/11489/Desktop/access_parquet")

  }
}

以下是读取出来的数据:

接着我们加载ip规则,然后将规则广播出去。

然后读取之前清洗好的parquet文件,注册视图,执行sql

接下来我们可以写一个自定义函数,将ip传进去,返回一个省,这样我们就不需要join了,在一张表就搞定了。

那我们不join,规则从哪来啊?其实我们之前将规则广播就相当于将规则在每一台机器上缓存起来了

我们要使用规则的话,只需要在函数中拿到这个广播的引用就好了。

而且在这个函数中我们可以使用二分查找提高查询效率

package com.test.SparkSQL

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object SQLIPLocation2 {
  val ip2Long = (ip:String )=> {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for(i <- 0 until fragments.length){
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }

  def binarySearch(lines:Array[(Long,Long,String)],ip:Long):Int = {
    var low = 0
    var high = lines.length - 1
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }

    }
    -1
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("SQLIPLocation2")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._
    val rulelines: Dataset[String] = spark.read.textFile("D:/a/ip.txt")
    val rule = rulelines.map(line => {
      val fields: Array[String] = line.split("[|]")
      val startNum: Long = fields(2).toLong
      val endNum: Long = fields(3).toLong
      val province: String = fields(6)
      (startNum, endNum, province)
    })

    //收集规则数据(Driver端)
    val rules: Array[(Long, Long, String)] = rule.collect()
    //广播
     val broadcast: Broadcast[Array[(Long, Long, String)]] = spark.sparkContext.broadcast(rules)
    //读取已经处理好的parquet文件
    val accesslog: DataFrame = spark.read.parquet("C:/Users/11489/Desktop/access_parquet")
    //注册视图
    accesslog.createTempView("v_access")
    spark.udf.register("ip2Province",(ip:String) => {
      val ru: Array[(Long, Long, String)] = broadcast.value
      val ipNum = ip2Long(ip)
      val index = binarySearch(ru,ipNum)
      ru(index)._3
    })
    //执行sql
   spark.sql("select ip2Province(ip),count(*) from v_access group by ip2Province(ip) ").show()
  }
}

运行结果:

猜你喜欢

转载自blog.csdn.net/qq_37050372/article/details/82959016
今日推荐