Spark SQL(二十二)用户自定义的UDF、UDAF函数

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Suubyy/article/details/82763512
  1. 用户自定义的UDF

    1. 定义:UDF(User-Defined-Function),也就是最基本的函数,它提供了SQL中对字段转换的功能,不涉及聚合操作。例如将日期类型转换成字符串类型,格式化字段。

    2. 用法

      object UDFTest {
        case class Person(name: String, age: Int)
        def main(args: Array[String]): Unit = {
          //常见SparkSession
          val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
          //根据文件获取RDD
          val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
      
          /**
            * 注册一个udf函数,
            * toString:为自定义函数的引用名,
            * (str: String) => str + "我是UDF自定义函数":这个是自定义的函数体,它是一个匿名函数
            */
          sparkSession.udf.register("toString", (str: String) => str + "我是UDF自定义函数")
      
          import sparkSession.implicits._
          //引入隐式转换
          //利用反射将RDD转换成DataFrame
          val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF()
      
          //将DataFrame注册成一张表
          personDF.createOrReplaceTempView("person")
      
          //利用Spark的SQL来查询数据,其中toString就是我们自定义的UDF函数
          sparkSession.sql("select toString(name),age from person").show()
        }
      }
      
      
  2. 用户自定义的UDAF

    1. 定义: UDAF函数是用户自定义的聚合函数,为Spark SQL提供对数据集的聚合功能,类似于max()、min()、count()等功能,只不过自定义的功能是根据具体的业务功能来确定的。因为DataFrame是弱类型的,DataSet是强类型,所以自定义的UDAF也提供了两种实现,一个是弱类型的一个是强类型的。

    2. 弱类型用法,需要继承UserDefindAggregateFunction,实现它的方法

      package com.lyz.sql.udf
      
      import org.apache.spark.sql.Row
      import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
      import org.apache.spark.sql.types._
      
      object MyCustomUDAF extends UserDefinedAggregateFunction {
        //:: Nil 作用就是为StructField常见Array集合,并放入进去
        def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)
      
        //缓存字段类型,也就是每个分区的共享变量
        def bufferSchema: StructType = StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
      
        //UDF输出数据类型
        def dataType: DataType = IntegerType
      
        //输入类型和输出类型是否一致
        def deterministic: Boolean = true
      
      
        //初始化分区中的共享变量
        def initialize(buffer: MutableAggregationBuffer): Unit = {
          //初始化每个分区上的年龄总和为0
          buffer(0) = 0
          
          //初始化每个分区上的人数为0
          buffer(1) = 0
        }
      
        //每个分区中每一条记录,聚合的时候需要调用该方法
        def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        
          //将新输入进来的数据一之前合并的结果做聚合操作,
          //buffer(0)就是上边定义的年龄总和sum,也就是每个分区上的年龄总和
          buffer(0) = buffer.getInt(0) + input.getInt(0)
          
          //buffer(1)就是上边定义的人的个数count,也就是每个分区上的人个数
          buffer(1) = buffer.getInt(1) + 1
        }
      
        //对分区结果进行合并
        def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
          // buffer1(0)就是所有分区的年龄总和
          //buffer1.getInt(0) + buffer2.getInt(0):就是将没分区上的年龄相加
          //下标为0的就是年龄总和
          buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
          
          //buffer(1)就是所有分区的人个数
          //buffer1.getInt(1) + buffer2.getInt(1):就是将每个分区人个数聚合在一起,
          //下标为1就是人的个数
          buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
        }
      
        //最终结算结果
        def evaluate(buffer: Row): Any = {
          buffer.getInt(0) / buffer.getInt(1)
        }
      }
      
      
      package com.lyz.sql.udf
      
      import com.lyz.sql.dataframe.DataFrameTest.Person
      import org.apache.spark.rdd.RDD
      import org.apache.spark.sql.{DataFrame, SparkSession}
      
      object MyCustomUDAFMain {
        def main(args: Array[String]): Unit = {
      
          val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
          
          //根据文件获取RDD
          val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
          import sparkSession.implicits._
          
          //引入隐式转换
          //利用反射将RDD转换成DataFrame
          val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF()
          sparkSession.udf.register("myCustomUDAF", MyCustomUDAF)
          personDF.createOrReplaceTempView("person")
      
          /**
            * 输出结果为:15
            */
          sparkSession.sql("select myCustomUDAF(age) from person").show()
        }
      }
      
      
    3. 强类型用法,需要继承Aggregate,实现它的方法。既然是强类型,那么其中肯定涉及到对象的存在

      package com.lyz.sql.udf
      
      import org.apache.spark.sql.{Encoder, Encoders}
      import org.apache.spark.sql.expressions.Aggregator
      
      //输入
      case class Person(name: String, age: Int)
      
      //缓存变量,也就是逻辑介质,
      case class Avg(sum: Int, count: Int)
      
      object MyCutomUDAFStrong extends Aggregator[Person, Avg, Int] {
      
        //初始化缓存变量
        def zero: Avg = Avg(0, 0)
      
        /**
          * 每个分区计算各自的结果
          *
          * @param b :聚合后的缓存变量
          * @param a :新输入的数据
          * @return b:聚合后的缓存变量
          */
        def reduce(b: Avg, a: Person): Avg = {
          b.sum += a.age
          b.count += 1
          b
        }
      
        //合并每个分区的结果
        def merge(b1: Avg, b2: Avg): Avg = {
          b1.sum += b2.sum
          b1.count += b2.count
          b1
        }
      
        //最后完成平均值的计算
        def finish(reduction: Avg): Int = {
          reduction.sum / reduction.count
        }
      
        //Encoders.product:是对scala元组和case类型转换的编码器
        def bufferEncoder: Encoder[Avg] = Encoders.product
      
        //设定输出值的编码器
        def outputEncoder: Encoder[Int] = Encoders.scalaInt
      }
      
      
      package com.lyz.sql.udf
      
      
      import org.apache.spark.rdd.RDD
      import org.apache.spark.sql.{Dataset, SparkSession, TypedColumn}
      
      object MyCustomStrongMain {
      
        def main(args: Array[String]): Unit = {
          val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
      
          //根据文件获取RDD
          val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
      
          import sparkSession.implicits._ //引入隐式转换
      
          //里用RDD生成Dataset
          val personDS: Dataset[Person] = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDS()
      
          //将这个函数转成TypedColumn,并且提供一个别名
          val avgAge: TypedColumn[Person, Int] = MyCustomUDAFStrong.toColumn.name("ageAvg")
      
          personDS.select(avgAge).show()
      
        }
      }
      
      

猜你喜欢

转载自blog.csdn.net/Suubyy/article/details/82763512