Spark 自定义函数(udf,udaf)

Spark 版本 2.3

文中测试数据(json)

{"name":"lillcol", "age":24,"ip":"192.168.0.8"}
{"name":"adson", "age":100,"ip":"192.168.255.1"}
{"name":"wuli", "age":39,"ip":"192.143.255.1"}
{"name":"gu", "age":20,"ip":"192.168.255.1"}
{"name":"ason", "age":15,"ip":"243.168.255.9"}
{"name":"tianba", "age":1,"ip":"108.168.255.1"}
{"name":"clearlove", "age":25,"ip":"222.168.255.110"}
{"name":"clearlove", "age":30,"ip":"222.168.255.110"}

用户自定义udf

自定义udf的方式有两种

  1. SQLContext.udf.register()
  2. 创建UserDefinedFunction

这两种个方式 使用范围不一样

package com.test.spark

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @author Administrator
  *         2019/7/22-14:04
  *
  */
object TestUdf {

  val spark = SparkSession
    .builder()
    .appName("TestCreateDataset")
    .config("spark.some.config.option", "some-value")
    .master("local")
    .enableHiveSupport()
    .getOrCreate()
  val sQLContext = spark.sqlContext

  import spark.implicits._


  def main(args: Array[String]): Unit = {
    testudf
  }

  def testudf() = {
    val iptoLong: UserDefinedFunction = getIpToLong()
    val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
    ds.createOrReplaceTempView("table1")
    sQLContext.udf.register("addName", sqlUdf(_: String)) //addName 只能在SQL里面用  不能在DSL 里面用
    //1.SQL
    sQLContext.sql("select *,addName(name) as nameAddName  from table1")
      .show()
    //2.DSL
    val addName: UserDefinedFunction = udf((str: String) => ("ip: " + str))
    ds.select($"*", addName($"ip").as("ipAddName"))
      .show()

    //如果自定义函数相对复杂,可以将它分离出去 如iptoLong
    ds.select($"*", iptoLong($"ip").as("iptoLong"))
      .show()
  }

  def sqlUdf(name: String): String = {
    "name:" + name
  }

  /**
    * 用户自定义 UDF 函数
    *
    * @return
    */
  def getIpToLong(): UserDefinedFunction = {
    val ipToLong: UserDefinedFunction = udf((ip: String) => {
      val arr: Array[String] = ip.replace(" ", "").replace("\"", "").split("\\.")
      var result: Long = 0
      var ipl: Long = 0
      if (arr.length == 4) {
        for (i <- 0 to 3) {
          ipl = arr(i).toLong
          result |= ipl << ((3 - i) << 3)
        }
      } else {
        result = -1
      }
      result
    })
    ipToLong
  }


}

输出结果
+---+---------------+---------+--------------+
|age|             ip|     name|   nameAddName|
+---+---------------+---------+--------------+
| 24|    192.168.0.8|  lillcol|  name:lillcol|
|100|  192.168.255.1|    adson|    name:adson|
| 39|  192.143.255.1|     wuli|     name:wuli|
| 20|  192.168.255.1|       gu|       name:gu|
| 15|  243.168.255.9|     ason|     name:ason|
|  1|  108.168.255.1|   tianba|   name:tianba|
| 25|222.168.255.110|clearlove|name:clearlove|
| 30|222.168.255.110|clearlove|name:clearlove|
+---+---------------+---------+--------------+

+---+---------------+---------+-------------------+
|age|             ip|     name|          ipAddName|
+---+---------------+---------+-------------------+
| 24|    192.168.0.8|  lillcol|    ip: 192.168.0.8|
|100|  192.168.255.1|    adson|  ip: 192.168.255.1|
| 39|  192.143.255.1|     wuli|  ip: 192.143.255.1|
| 20|  192.168.255.1|       gu|  ip: 192.168.255.1|
| 15|  243.168.255.9|     ason|  ip: 243.168.255.9|
|  1|  108.168.255.1|   tianba|  ip: 108.168.255.1|
| 25|222.168.255.110|clearlove|ip: 222.168.255.110|
| 30|222.168.255.110|clearlove|ip: 222.168.255.110|
+---+---------------+---------+-------------------+

+---+---------------+---------+----------+
|age|             ip|     name|  iptoLong|
+---+---------------+---------+----------+
| 24|    192.168.0.8|  lillcol|3232235528|
|100|  192.168.255.1|    adson|3232300801|
| 39|  192.143.255.1|     wuli|3230662401|
| 20|  192.168.255.1|       gu|3232300801|
| 15|  243.168.255.9|     ason|4087938825|
|  1|  108.168.255.1|   tianba|1823014657|
| 25|222.168.255.110|clearlove|3735617390|
| 30|222.168.255.110|clearlove|3735617390|
+---+---------------+---------+----------+

用户自定义 UDAF 函数(即聚合函数)

弱类型用户自定义聚合函数

通过继承UserDefinedAggregateFunction

package com.test.spark

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @author lillcol
  *         2019/7/22-15:09
  *         弱类型用户自定义聚合函数
  */
object TestUDAF extends UserDefinedAggregateFunction {
  // 聚合函数输入参数的数据类型
  // :: 用于的是向队列的头部追加数据,产生新的列表,Nil 是一个空的 List,定义为 List[Nothing]
  override def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)

  //等效于
  //  override def inputSchema: StructType=new StructType() .add("age", IntegerType).add("name", StringType)

  // 聚合缓冲区中值的数据类型
  override def bufferSchema: StructType = {
    StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
  }

  // UserDefinedAggregateFunction返回值的数据类型。
  override def dataType: DataType = DoubleType

  // 如果这个函数是确定的,即给定相同的输入,总是返回相同的输出。
  override def deterministic: Boolean = true

  //  初始化给定的聚合缓冲区,即聚合缓冲区的零值。
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // sum,  总的年龄
    buffer(0) = 0
    // count, 人数
    buffer(1) = 0
  }

  //  使用来自输入的新输入数据更新给定的聚合缓冲区。
  // 每个输入行调用一次。(同一分区)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + input.getInt(0) //年龄 叠加
    buffer(1) = buffer.getInt(1) + 1 //人数叠加
  }

  //  合并两个聚合缓冲区并将更新后的缓冲区值存储回buffer1。
  // 当我们将两个部分聚合的数据合并在一起时,就会调用这个函数。(多个分区)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) //年龄 叠加
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1) //人数叠加
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getInt(0).toDouble / buffer.getInt(1)
  }

  val spark = SparkSession
    .builder()
    .appName("Spark SQL basic example")
    // .config("spark.some.config.option", "some-value")
    .master("local[*]") // 本地测试
    .getOrCreate()

  import spark.implicits._

  def main(args: Array[String]): Unit = {
    spark.udf.register("myAvg", TestUDAF)
    val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
    ds.createOrReplaceTempView("table1")
    //SQL
    spark.sql("select myAvg(age) as avgAge from table1")
      .show()

    //DSL
    val myavg = TestUDAF
    ds.select(TestUDAF($"age").as("avgAge"))
      .show()
  }
}

输出结果:
+------+
|avgAge|
+------+
| 31.75|
+------+

+------+
|avgAge|
+------+
| 31.75|
+------+

强类型用户自定义聚合函数

通过继承Aggregator(是org.apache.spark.sql.expressions 下的 不要引错包了)

package com.test.spark

import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions._

/**
  * @author Administrator
  *         2019/7/22-16:07
  *
  */
// 既然是强类型,可能有 case 类
case class Person(name: String, age: Double, ip: String)

case class Average(var sum: Double, var count: Double)

object MyAverage extends Aggregator[Person, Average, Double] {
  //  此聚合的值为零。应该满足任意b + 0 = b的性质。
  //  定义一个数据结构,保存工资总数和工资总个数,初始都为0
  override def zero: Average = {
    Average(0, 0)
  }

  //  将两个值组合起来生成一个新值。为了提高性能,函数可以修改b并返回它,而不是为b构造新的对象。
  //  相同 Execute 间的数据合并(同一分区)
  override def reduce(b: Average, a: Person): Average = {
    b.sum += a.age
    b.count += 1
    b
  }

  // 合并两个中间值。
  // 聚合不同 Execute 的结果(不同分区)
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // 计算最终结果
  override def finish(reduction: Average): Double = {
    reduction.sum.toInt / reduction.count
  }

  //  为中间值类型指定“编码器”。
  override def bufferEncoder: Encoder[Average] = Encoders.product

  //  为最终输出值类型指定“编码器”。
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

  val spark = SparkSession
    .builder()
    .appName("Spark SQL basic example")
    // .config("spark.some.config.option", "some-value")
    .master("local[*]") // 本地测试
    .getOrCreate()

  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val ds: Dataset[Person] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson").as[Person]
    ds.show()

    val avgAge = MyAverage.toColumn/*.name("avgAge")*///指定该列的别名为avgAge
    ds.select(avgAge)//执行avgAge.as("columnName") 汇报org.apache.spark.sql.AnalysisException错误  别名只能在上面指定(目前测试是这样)
      .show()
  }
}

输出结果:
+---+---------------+---------+
|age|             ip|     name|
+---+---------------+---------+
| 24|    192.168.0.8|  lillcol|
|100|  192.168.255.1|    adson|
| 39|  192.143.255.1|     wuli|
| 20|  192.168.255.1|       gu|
| 15|  243.168.255.9|     ason|
|  1|  108.168.255.1|   tianba|
| 25|222.168.255.110|clearlove|
| 30|222.168.255.110|clearlove|
+---+---------------+---------+

+------+
|avgAge|
+------+
| 31.75|
+------+

本文为原创文章,转载请注明出处!!!

猜你喜欢

转载自www.cnblogs.com/lillcol/p/11229044.html