spark-sql 自定义函数

(1)自定义UDF

/**
  * Demonstrates registering a plain Scala method as a Spark SQL UDF and
  * calling it from a SQL statement.
  */
object SparkSqlTest {
    /**
      * Builds a local SparkSession, registers the `strLen` UDF, runs a sample
      * query that calls it, and prints the result.
      *
      * @param args command-line arguments (unused)
      */
    def main(args: Array[String]): Unit = {
        // Silence noisy framework logging
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)

        // Build the programming entry point
        val conf: SparkConf = new SparkConf()
            .setAppName("SparkSqlTest")
            .setMaster("local[2]")
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()

        val sql=
            """
              |select strLen("zhangsan")
            """.stripMargin

        // Always release the session, even when registration or the query fails.
        try {
            // Obtain the SQLContext backing this session
            val sqlContext: SQLContext = spark.sqlContext

            /**
              * Register the UDF.
              * In the type parameters [Integer, String] the first type is the
              * return type; the following one (or more) are the argument types.
              * The boxed Integer is used so that the UDF can return SQL NULL
              * when its argument is NULL.
              */
            sqlContext.udf.register[Integer, String]("strLen", strLen _)
            spark.sql(sql).show()
        } finally {
            spark.stop()
        }
    }

    /**
      * UDF implementation: length of a string.
      *
      * @param str the input string; may be null when the SQL argument is NULL
      * @return the number of characters in `str`, or null when `str` is null
      */
    def strLen(str: String): Integer = {
        // Spark passes SQL NULL as a JVM null for String arguments; propagate it
        // instead of failing the job with a NullPointerException.
        if (str == null) null else str.length
    }
}

(2) 自定义UDAF

下面以实现一个简单的 count 聚合函数为例。
自定义 UDAF 类:

    /**
      * A user-defined aggregate function that counts the rows in each group,
      * keeping the running count in a single integer buffer slot.
      */
    class MyCountUDAF extends UserDefinedAggregateFunction {

        /** Schema of the values passed to this UDAF: one integer column named "age". */
        override def inputSchema: StructType =
            StructType(StructField("age", DataTypes.IntegerType) :: Nil)

        /** Schema of the intermediate aggregation buffer: one integer column named "age". */
        override def bufferSchema: StructType =
            StructType(StructField("age", DataTypes.IntegerType) :: Nil)

        /** Type of the value produced by this UDAF. */
        override def dataType: DataType = DataTypes.IntegerType

        /** Declares the function deterministic. */
        override def deterministic: Boolean = true

        /**
          * Resets the buffer that temporarily holds the aggregate during computation.
          *
          * @param buffer the aggregation buffer; slot 0 is set to 0
          */
        override def initialize(buffer: MutableAggregationBuffer): Unit = {
            buffer(0) = 0
        }

        /**
          * Aggregation within a partition: adds one to the running count.
          *
          * @param buffer the temporary buffer set up by `initialize`
          * @param input  the newly arrived input row (its value is not read)
          */
        override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
            val countSoFar = buffer.getInt(0)
            buffer(0) = countSoFar + 1
        }

        /**
          * Aggregation across partitions: sums two partial counts into `buffer1`.
          *
          * @param buffer1 partial result of the first partition; receives the sum
          * @param buffer2 partial result of the second partition
          */
        override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
            val left = buffer1.getInt(0)
            val right = buffer2.getInt(0)
            buffer1(0) = left + right
        }

        /**
          * Final value of the aggregate.
          *
          * @param buffer the fully merged aggregation buffer
          * @return the content of buffer slot 0
          */
        override def evaluate(buffer: Row): Any = buffer(0)
    }

调用:

/**
  * Demonstrates registering the `MyCountUDAF` aggregate function and using it
  * in a grouped SQL query over an in-memory Dataset of students.
  */
object SparkSqlTest {
    /**
      * Builds a local SparkSession, registers the UDAF as `myCount`, exposes a
      * small student Dataset as the `student` view, and prints the per-age
      * counts ordered by count.
      *
      * @param args command-line arguments (unused)
      */
    def main(args: Array[String]): Unit = {
        // Silence noisy framework logging
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)

        // Build the programming entry point
        val conf: SparkConf = new SparkConf()
            .setAppName("SparkSqlTest")
            .setMaster("local[2]")
            .set("spark.serializer","org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(Array(classOf[Student]))
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()

        // Case classes are built through their companion apply; no `new` needed.
        val stuList = List(
            Student("委xx", 18),
            Student("吴xx", 18),
            Student("戚xx", 18),
            Student("王xx", 19),
            Student("薛xx", 19)
        )
        val sql=
            """
              |select myCount(1) counts
              |from student
              |group by age
              |order by counts
            """.stripMargin

        // Always release the session, even when registration or the query fails.
        try {
            // Obtain the SQLContext backing this session
            val sqlContext: SQLContext = spark.sqlContext

            // Register the UDAF
            sqlContext.udf.register("myCount",new MyCountUDAF())

            import spark.implicits._
            val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)
            // Replace rather than fail when getOrCreate() handed back a session
            // in which the "student" view already exists.
            stuDS.createOrReplaceTempView("student")
            spark.sql(sql).show()
        } finally {
            spark.stop()
        }
    }

}
/**
  * A student record used as the element type of the example Dataset.
  *
  * @param name the student's name
  * @param age  the student's age
  */
case class Student(name: String, age: Int)

猜你喜欢

转载自blog.51cto.com/14048416/2339276