Spark SQL Quick Start Series (2) | Custom Spark SQL Functions

1. Data Source

{"name":"lisi","age":20}
{"name":"ww","age":10}
{"name":"zl","age":15}
{"name":"zy","age":30}

2. Custom UDF Functions

import org.apache.spark.sql.SparkSession

object UDFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("UDFDemo")
      .getOrCreate()

    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")
    // Register a UDF: "toUpper" is the SQL-visible name; the second argument is the implementation
    // (toUpperCase converts the string to upper case)
    spark.udf.register("toUpper",(s: String) => s.toUpperCase)
    df.createOrReplaceTempView("user")
    spark.sql("select toUpper(name),age from user").show()

    spark.close()

  }
}
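
The same logic can also be applied without SQL registration, through the DataFrame API's udf helper. A minimal sketch, assuming the same df as above (toUpperUdf is an illustrative name):

import org.apache.spark.sql.functions.{col, udf}

// Wrap the same lambda as a column-level UDF; no SQL registration needed
val toUpperUdf = udf((s: String) => s.toUpperCase)

df.select(toUpperUdf(col("name")).as("name"), col("age")).show()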

Result

The name column comes back upper-cased (LISI, WW, ZL, ZY) alongside each age.

3. User-Defined Aggregate Functions (UDAF)

  • Both the strongly-typed Dataset and the weakly-typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions.
  • For the weakly-typed DataFrame API, extend UserDefinedAggregateFunction

A sum()-style aggregation

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAFDemo")
      .getOrCreate()

    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")
    df.createOrReplaceTempView("user")
    // Register the aggregate function
    spark.udf.register("mySum", new MySum)
    spark.sql("select mySum(age) from user").show
    spark.close()
  }
}
class MySum extends UserDefinedAggregateFunction {

  // Schema of the input arguments, e.g. 10.1, 12.2
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // Schema of the aggregation buffer (intermediate state)
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)

  // Type of the final result
  override def dataType: DataType = DoubleType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initialize the running sum in the buffer
    buffer(0) = 0D // equivalent to buffer.update(0, 0D)
  }

  // Aggregation within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input wraps the arguments passed to the aggregate function in a Row
    if (!input.isNullAt(0)) { // the column value may be null
      val v = input.getAs[Double](0) // equivalent to getDouble(0)
      buffer(0) = buffer.getDouble(0) + v
    }
  }

  // Aggregation across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Combine buffer1 and buffer2, writing the result back into buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // Produce the final output value
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
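
With the sample data, select mySum(age) adds 20 + 10 + 15 + 30 and prints 75.0 (Spark casts the long age column to Double to match inputSchema).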

An avg()-style aggregation

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

object UDAFDemo1 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAFDemo1")
      .getOrCreate()

    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")
    df.createOrReplaceTempView("user")
    // Register the aggregate function
    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select myAvg(age) from user").show
    spark.close()
  }
}
class MyAvg extends UserDefinedAggregateFunction {

  // Schema of the input arguments, e.g. 10.1, 12.2
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // Schema of the aggregation buffer: a running sum and a count
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

  // Type of the final result
  override def dataType: DataType = DoubleType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D // running sum; equivalent to buffer.update(0, 0D)
    buffer(1) = 0L // count
  }

  // Aggregation within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input wraps the arguments passed to the aggregate function in a Row
    if (!input.isNullAt(0)) { // the column value may be null
      val v = input.getAs[Double](0) // equivalent to getDouble(0)
      buffer(0) = buffer.getDouble(0) + v
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  // Aggregation across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Combine buffer1 and buffer2, writing the result back into buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Produce the final output value
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}
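
Because the running sum is stored as a Double, the division in evaluate is floating-point rather than integer division; with the sample data, myAvg(age) returns (20 + 10 + 15 + 30) / 4 = 18.75.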

4. Strongly-Typed Custom Aggregate Functions (for reference)

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator


case class Dog(name: String, age: Int)
case class AgeAvg(sum: Int, count: Int) {
  def avg = sum.toDouble / count
}

object UDAFDemo2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDAFDemo2")
      .getOrCreate()
    import spark.implicits._

    val ds = List(Dog("大黄", 6), Dog("小黄", 2), Dog("中黄", 4)).toDS()
    // Strongly-typed usage: turn the Aggregator into a TypedColumn
    val avg = new MyAvg2().toColumn.name("avg")
    val result = ds.select(avg)
    result.show()
    spark.close()
  }
}
class MyAvg2 extends Aggregator[Dog, AgeAvg, Double] {
  // Initialize the buffer
  override def zero: AgeAvg = AgeAvg(0, 0)

  // Aggregation within a partition
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    // For a Dog object, add its age to the sum and increment the count
    case Dog(_, age) => AgeAvg(b.sum + age, b.count + 1)
    // For null, return the buffer unchanged
    case _ => b
  }

  // Aggregation across partitions
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg =
    AgeAvg(b1.sum + b2.sum, b1.count + b2.count)

  // Produce the final value
  override def finish(reduction: AgeAvg): Double = reduction.avg

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product // works for any case class

  // Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
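
Since Spark 3.0, UserDefinedAggregateFunction is deprecated, and an Aggregator can also be registered for SQL use via org.apache.spark.sql.functions.udaf. A minimal sketch under that assumption: MyAvg3 is a hypothetical variant that aggregates the age column directly instead of a Dog object, reusing the AgeAvg buffer, and it assumes the spark session and the user view from the earlier examples.

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// A primitive-input variant: aggregates a Long age column instead of Dog objects
class MyAvg3 extends Aggregator[Long, AgeAvg, Double] {
  override def zero: AgeAvg = AgeAvg(0, 0)
  // The sample ages fit in an Int, so the narrowing cast is safe here
  override def reduce(b: AgeAvg, a: Long): AgeAvg = AgeAvg(b.sum + a.toInt, b.count + 1)
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = AgeAvg(b1.sum + b2.sum, b1.count + b2.count)
  override def finish(r: AgeAvg): Double = r.avg
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Spark 3.0+: functions.udaf turns an Aggregator into a SQL-registrable function
spark.udf.register("myAvg3", functions.udaf(new MyAvg3, Encoders.scalaLong))
spark.sql("select myAvg3(age) from user").show()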


Reposted from blog.csdn.net/qq_46548855/article/details/108266351