Spark SQL --- Custom Functions (UDF, UDTF, UDAF)

User-defined functions are commonly referred to as UDFs; they fall into three kinds:

UDF : one input row, one result; one-to-one. For example, a function that takes an IP address and returns the corresponding province.
UDTF: one input row, multiple output rows (as in Hive); one-to-many. Spark SQL has no UDTF; in Spark the same effect is achieved with flatMap (see the sketch after this list).
UDAF: multiple input rows, one result row. Aggregations such as count and sum are built into Spark, but more complex business logic requires defining your own aggregate function.
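
A Hive-style UDTF (explode, for example) turns one input row into several output rows. In Spark, the Dataset flatMap operator, or the built-in explode function on DataFrames, gives the same effect. A minimal sketch of this point (the object name and the sample data are made up for illustration):

//Hypothetical example: split one comma-separated row into several rows,
//which is what a Hive UDTF such as explode would do.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, split}

object FlatMapAsUdtf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("FlatMapAsUdtf").master("local[*]").getOrCreate()
    import spark.implicits._

    val lines = Seq("spark,sql", "hive,udf,udtf").toDS()

    //Dataset API: flatMap produces zero or more output rows per input row
    lines.flatMap(_.split(",")).show()

    //DataFrame API: the built-in explode function does the same for array columns
    lines.toDF("tags").select(explode(split($"tags", ","))).show()

    spark.stop()
  }
}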

Below, the use of UDF and UDAF is explained:

UDF

Case study: determine the location (province) from an IP address

Define a function that takes an IP address and returns the corresponding province, then register it; after that, the custom function can be used directly in SQL statements.
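
Before the full case study, here is a minimal sketch of that register-then-use pattern (the function name to_upper, the view name and the sample data are made up for illustration):

import org.apache.spark.sql.SparkSession

object UdfRegisterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UdfRegisterSketch").master("local[*]").getOrCreate()
    import spark.implicits._

    //some sample data registered as a view
    Seq("spark", "sql").toDF("word").createTempView("v_words")

    //register an ordinary Scala function under a SQL-visible name
    spark.udf.register("to_upper", (s: String) => s.toUpperCase)

    //the registered name can now be called like any built-in SQL function
    spark.sql("SELECT to_upper(word) AS upper_word FROM v_words").show()

    spark.stop()
  }
}

The full case study follows: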

package XXX

import cn.edu360.sparkIpTest.TestIp
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
  * Created by ...
  *
  * A join against the IP-rule table would be too expensive and very slow;
  * the solution is to cache the IP rules on every Executor via a broadcast variable.
  */
object IpLocationSQL2 {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName("IpLocationSQL")
      .master("local[4]")
      .getOrCreate()

    //Read the IP rule file (e.g. ip.txt on HDFS)
    import spark.implicits._
    val rulesLines: Dataset[String] = spark.read.textFile("path of the IP rules file")

    //Parse the IP rule data
    //This runs on the Executors; each Executor parses only part of the IP rule data
    val ipRulesDataset: Dataset[(Long, Long, String)] = rulesLines.map(line => {
      val fields = line.split("[|]")
      val startNum = fields(2).toLong
      val endNum = fields(3).toLong
      val province = fields(6)
      (startNum, endNum, province)
    })

    val ipRulesInDriver: Array[(Long, Long, String)] = ipRulesDataset.collect()

    //Broadcast the IP rules to the Executors
    val ipRulesBroadcastRef: Broadcast[Array[(Long, Long, String)]] = spark.sparkContext.broadcast(ipRulesInDriver)


    //Next, read the access log data
    val accessLines: Dataset[String] = spark.read.textFile("path of the access log")

    //Parse the log lines: extract the IP, convert it to decimal, then match it against the IP rules (using binary search)
    val ips: Dataset[Long] = accessLines.map(line => {
      val fields = line.split("[|]")
      val ip = fields(1)
      //Convert the IP to its decimal form
      val ipNum = TestIp.ip2Long(ip)
      ipNum
    })

    val ipDataFrame: DataFrame = ips.toDF("ipNum")

    //Create a temporary view
    ipDataFrame.createTempView("v_ipNum")

    //Define a custom function (UDF) and register it
    //The function takes the decimal form of an IP address and returns the corresponding province name
    spark.udf.register("ip2Province",(ipNum:Long) => {
      //Look up the IP rules (they were broadcast beforehand, so they are already on the Executors)
      //The function body runs on the Executors; through the broadcast variable reference it can access the IP rule data
      val ipRulesInExecutor: Array[(Long, Long, String)] = ipRulesBroadcastRef.value
      //Binary-search the rules for the province matching the decimal IP
      val index: Int = TestIp.binarySearch(ipRulesInExecutor,ipNum)
      var province = "unknown"
      if (index != -1){
        province = ipRulesInExecutor(index)._3
      }
      province
    })

    //Run the SQL query
    val result: DataFrame = spark.sql("SELECT ip2Province(ipNum) province,COUNT(*) counts FROM v_ipNum GROUP BY province ORDER BY counts DESC")

    result.show()

    //Release resources
    spark.stop()


  }

}
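
The example above depends on two helpers from the TestIp object that are not shown: ip2Long converts a dotted IPv4 string to its decimal Long form, and binarySearch finds the rule range that contains that number. A minimal sketch of what they might look like (an assumption, not the author's original code; it assumes the rule tuples are sorted by their start value):

//Hypothetical sketch of the TestIp helpers used above
object TestIp {

  //Convert a dotted IPv4 string such as "1.2.3.4" to its decimal (Long) form
  def ip2Long(ip: String): Long = {
    ip.split("[.]").foldLeft(0L)((acc, part) => (acc << 8) | part.toLong)
  }

  //Binary search over rules sorted by start value; returns the index of the
  //(start, end, province) tuple whose range contains ipNum, or -1 if none does
  def binarySearch(rules: Array[(Long, Long, String)], ipNum: Long): Int = {
    var low = 0
    var high = rules.length - 1
    while (low <= high) {
      val middle = (low + high) / 2
      if (ipNum >= rules(middle)._1 && ipNum <= rules(middle)._2) {
        return middle
      } else if (ipNum < rules(middle)._1) {
        high = middle - 1
      } else {
        low = middle + 1
      }
    }
    -1
  }
}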

UDAF (user-defined aggregate functions)

In Spark, a custom aggregate function must extend the abstract class UserDefinedAggregateFunction and override its methods.
First, a look at the source of this class:

abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any
}

As the excerpt shows, after extending this class you need to override eight methods.
What each method means:

inputSchema: the data types of the input arguments
bufferSchema: the data types of the intermediate results (the aggregation buffer)
dataType: the data type of the final result
deterministic: whether the function is deterministic, i.e. the same input always produces the same output; usually true
initialize: sets the initial values of the aggregation buffer
update: updates the intermediate results every time a row takes part in the computation (update is the per-partition computation)
merge: the global aggregation (merges the results of the individual partitions)
evaluate: computes the final result

Case study: computing the geometric mean (the n-th root of the product of n numbers)

package XXX

import java.lang

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
  * Created by ...
  *
  */
object UdafTest {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName("UdafTest")
      .master("local[4]")       //run in local mode
      .getOrCreate()

    //Instantiate the custom aggregate function
    val geoMean = new GeoMean

    //Test data: create a Dataset of the numbers 1 to 10
    val range: Dataset[lang.Long] = spark.range(1,11)

    //==========SQL style===========

    //Register the range Dataset as a view
//    range.createTempView("v_range")
//    //Register our custom aggregate function
//    spark.udf.register("gm",geoMean)
//    //Write the SQL
//    val result: DataFrame = spark.sql("SELECT gm(id) result FROM v_range")


    //===========DSL style============

    import spark.implicits._
    val result: DataFrame = range.groupBy().agg(geoMean($"id").as("result"))

    //Show the result
    result.show()

    spark.stop()

  }    
}

class GeoMean extends UserDefinedAggregateFunction{
  //Data type of the input
  override def inputSchema: StructType = StructType(List(
    StructField("value",DoubleType)
  ))

  //Data types of the intermediate results (the aggregation buffer)
  override def bufferSchema: StructType = StructType(List(
    //the running product of the values multiplied so far
    StructField("product",DoubleType),
    //the count of the numbers that have taken part in the computation
    StructField("counts",LongType)
  ))

  //Data type of the final result
  override def dataType: DataType = DoubleType

  //Whether the function is deterministic; usually true
  override def deterministic: Boolean = true

  //Set the initial values of the aggregation buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //initial value of the product
    buffer(0) = 1.0
    //initial value of the count
    buffer(1) = 0L
  }

  //Update the intermediate results whenever a row takes part in the computation (update is the per-partition computation)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //multiply the previous intermediate product by the new input value
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    //and increment the count of values seen so far
    buffer(1) = buffer.getLong(1) + 1L
  }

  //Global aggregation (merge the per-partition results)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //multiply the partial products of the two buffers
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    //add up the partial counts of the two buffers
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //Compute the final result
  override def evaluate(buffer: Row): Any = {
    //take the product and the count from the buffer and raise the product to the power 1/count: the geometric mean
    math.pow(buffer.getDouble(0),1.toDouble / buffer.getLong(1))
  }
}
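
For the test data above (the numbers 1 through 10), the value shown in the result column should be the 10th root of the product 1·2·...·10 = 3628800, which is roughly 4.53.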


Reposted from blog.csdn.net/weixin_43866709/article/details/88914871