Spark中UDF、UDAF、UDTF的使用

一、UDF

测试数据 user.json:

{
    
    "id": 1001, "name": "foo", "sex": "man", "age": 20}
{
    
    "id": 1002, "name": "bar", "sex": "man", "age": 24}
{
    
    "id": 1003, "name": "baz", "sex": "man", "age": 18}
{
    
    "id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{
    
    "id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{
    
    "id": 1006, "name": "baz3", "sex": "woman", "age": 20}

1、通过匿名函数的方式注册自定义算子

user.json中的woman、man改为female、male

	//创建sparksession
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("SparkUDF")
      .enableHiveSupport()      //启用hive
      .getOrCreate()
    
    //sparksession直接读取csv,可设置分隔符delimitor.
    val userDF = spark.read.json("in/user.json")
    val sc: SparkContext = spark.sparkContext
    
    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")

    //通过匿名函数的方式注册自定义算子:将woman和man分别转换成female和male
    spark.udf.register("Sex",(sex:String)=>{
    
    
      var result="unknown"
      if (sex=="woman"){
    
    
        result="female"
      }else if(sex=="man"){
    
    
        result="male"
      }
      result
    })
    spark.sql("select Sex(sex) from userDF").show()

运行结果如下:

+------------+
|UDF:Sex(sex)|
+------------+
|        male|
|        male|
|        male|
|      female|
|      female|
|      female|
+------------+

2、通过实名函数的方式注册自定义算子

    //创建sparksession
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("SparkUDF")
      .enableHiveSupport()      //启用hive
      .getOrCreate()

    //sparksession直接读取csv,可设置分隔符delimitor.
    val userDF = spark.read.json("in/user.json")
    val sc: SparkContext = spark.sparkContext

    //将DataFrame注册成视图,然后即可使用hql访问
    userDF.createOrReplaceTempView("userDF")
    /*
    通过实名函数的方式注册自定义算子
    Scala中方法和函数是两个不同的概念,方法无法作为参数进行传递,
    也无法赋值给变量,但是函数是可以的。在Scala中,利用下划线可以将方法转换成函数:
    */
    spark.udf.register("sex",Sex _)
    spark.sql("select Sex(sex) as A from userDF").show()
  }

  //将woman和man分别转换成female和male
  def Sex(sex:String): String ={
    
    
    var result="unknown"
    if (sex=="woman"){
    
    
      result="female"
    }else if(sex=="man"){
    
    
      result="male"
    }
    result
  }

运行结果如下:

+------+
|     A|
+------+
|  male|
|  male|
|  male|
|female|
|female|
|female|
+------+

二、UDAF

1、UDAF简介

先解释一下什么是UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。
即输入多行数据,产生一个输出

关于UDAF的一个误区

我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:

		select max(foo) from foobar group by bar;

表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

		select max(foo) from foobar;

这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

2、UDAF使用

2.1 继承UserDefinedAggregateFunction

使用UserDefinedAggregateFunction的套路:

  1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现

  2. 在spark中注册UDAF,为其绑定一个名字

  3. 然后就可以在sql语句中使用上面绑定的名字调用

下面写一个计算平均值的 UDAF 例子

首先定义一个MyAgeAvgFunction类继承UserDefinedAggregateFunction

class MyAgeAvgFunction extends UserDefinedAggregateFunction {
    
    
  //聚合函数的输入数据结构
  override def inputSchema: StructType = {
    
    
    new StructType().add("age",LongType)
//    StructType(StructField("age",LongType)::Nil)		//作用同上
  }

  //缓存里面的数据结构
  override def bufferSchema: StructType = {
    
    
    new StructType().add("sum",LongType).add("count",LongType)
//    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)		//作用同上
  }

  //聚合函数返回值的数据结构
  override def dataType: DataType = {
    
    
    DoubleType
  }

  //聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  override def deterministic: Boolean = true

  //初始化缓冲区
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    
    
    buffer(0)=0L
    buffer(1)=0L
  }

  //给聚合函数传入一条新数据进行处理
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    
    
    buffer(0)=buffer.getLong(0)+input.getLong(0)
    buffer(1)=buffer.getLong(1)+1
  }

  //合并聚合缓冲区
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    
    
    //总年龄数
    buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
    //个数
    buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
  }

  //计算最终结果
  override def evaluate(buffer: Row): Any = {
    
    
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}

然后注册并使用它:

    val spark = SparkSession.builder()
      .appName("SparkUDAF")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._
    val sc: SparkContext = spark.sparkContext
    val df: DataFrame = spark.read.json("in/user.json")

    //创建并注册自定义udaf函数
    val myUdaf = new MyAgeAvgFunction
    spark.udf.register("myAvgAge",myUdaf)

    df.createOrReplaceTempView("userinfo")
    val resultDF: DataFrame = spark
      .sql("select sex,Round(myAvgAge(age),2) as avgage from userinfo group by sex")
    resultDF.show()

数据集user.json:

{
    
    "id": 1001, "name": "foo", "sex": "man", "age": 20}
{
    
    "id": 1002, "name": "bar", "sex": "man", "age": 24}
{
    
    "id": 1003, "name": "baz", "sex": "man", "age": 18}
{
    
    "id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{
    
    "id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{
    
    "id": 1006, "name": "baz3", "sex": "woman", "age": 20}

运行结果如下:

+-----+------+
|  sex|avgage|
+-----+------+
|  man| 20.67|
|woman| 18.67|
+-----+------+

3、继承Aggregator

还有另一种方式就是继承Aggregator这个类,优点是可以带类型
此处省略该方法代码演示,需要的请看:
https://www.cnblogs.com/cc11001100/p/9471859.html

三、UDTF

1、UDTF简介

通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来自定义 UDTF 算子,UDTF 是一行输入,多行输出

2、UDTF使用

2.1、继承GenericUDTF

创建MyUDTF类继承GenericUDTF

class MyUDTF extends GenericUDTF{
    
    
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    
    
    if(argOIs.length!=1){
    
    
      throw new UDFArgumentException("有且只能有一个参数传入")
    }
    if (argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
    
    
      throw new UDFArgumentException("参数类型不匹配")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]()
    fieldNames.add("type")
    //这里定义输出列字段类型
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldOIs)
  }

  //传入 Hadoop Scala kafka hive hbase Oozie
  //输出    HEAD  type  String
  //Hadoop
  //Scala
  //kafka
  //hive
  //hbase
  //Oozie
  
  override def process(objects: Array[AnyRef]): Unit = {
    
    
    //将字符串切分成单个单词
    val strings: Array[String] = objects(0).toString.split(" ")
    
    for (str <- strings){
    
    
      val tmp = new Array[String](1)
      tmp(0)=str
      forward(tmp)
    }
  }

  override def close(): Unit = {
    
     }
}

然后获取hive支持并使用它:

    val spark = SparkSession.builder()
      .appName("SparkUDTFDemo")
      .master("local[*]")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    val sc: SparkContext = spark.sparkContext
    val lines: RDD[String] = sc.textFile("in/udtf.txt")
    val stuDF: DataFrame = lines.map(_.split("//"))
      .filter(x => x(1).equals("ls"))
      .map(x=> (x(0), x(1), x(2)))
      .toDF("id", "name", "class")

    stuDF.createOrReplaceTempView("student")
    spark.sql("create temporary function MyUDTF as 'shuju.MyUDTF' ")
    val resultDF: DataFrame = spark.sql("select MyUDTF(class) from student")
    resultDF.show()

数据集udtf.txt:

01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop

运行结果如下:

+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+

猜你喜欢

转载自blog.csdn.net/qq_42578036/article/details/109749253