SparkSQL 自定义函数UDF与UDAF

自定义函数分类

UDF 输入一行,输出一行

UDAF 输入多行,输出一行

UDTF 输入一样,输出多行

UDF
//导包
import org.apache.spark.sql.SparkSession
//编写代码
// 1.实例SparkSession
    val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
// 2.根据sparkSession获取SparkContext
    val sc = spark.sparkContext
//3.读取数据并输出
    val datas = spark.read.textFile("./data/udf/udf.txt")
//4.数据展示
    datas.show()
//5.编写UDF将小写变成大写
    spark.udf.register("smallToBig", (str: String) => str.toUpperCase())
//6.将RDD转换为DataFrame
    val dataFrame = datas.toDF()
//7.注册表
    dataFrame.createOrReplaceTempView("word")
//8.使用自定义函数查询 并输出
    spark.sql("select value, smallToBig(value) from word").show()
UDAF

继承UserDefinedAggregateFunction方法重写说明

InputSchema: 输入数据的类型

bufferSchema: 产生中间结果的数据类型

dataType:最终返回的结果类型

dataeministic: 确保一致性,一般用true

initialize: 指定初始值

update:每有一条数据参与运算就更新一下中间结果(update相当于每一个分区中的运算)

merge:全局聚合(将每个分区的结果进行聚合)

evaluate: 计算最终的结果

//导包
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
//编写自定义UDAF
class MyUDAF extends UserDefinedAggregateFunction {
// 输入的数据类型的schema
    override def inputSchema: StructType = {
      StructType(StructField("input", LongType) :: Nil)
    }
//缓冲去数据类型schema 就是转换字后的数据schema
    override def bufferSchema: StructType = {
      StructType(StructField("sum", LongType) :: StructField("total", LongType) :: Nil)
    }
// 返回值数据类型
    override def dataType: DataType = {
      DoubleType
    }
// 确定是否相同的输入会有相同的输出
    override def deterministic: Boolean = true
// 初始化内部数据结构
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }
// 更新数据内部结构,区内计算
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 所有的金额
      buffer(0) = buffer.getLong(0) + input.getLong(0)
// 一共多少条数据
      buffer(1) = buffer.getLong(1) + 1
    }
// 来字不同分区数据进行合并,全局合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }
// 计算输出数据值
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0).toDouble / buffer.getLong(1)
    }
  }
//编写测试代码
//1.实例SparkSession
    val spark = SparkSession.builder().master("local[*]").appName("sql").getOrCreate()
//2.根据SparkSession获取SparkContext 获取上下文对象
    val sc = spark.sparkContext
//3.使用SparkContext 读取数据
    val dataRDD = spark.read.json("./data/udf/udaf.json")
//4.注册表
    dataRDD.createOrReplaceTempView("word")
//5.注册 UDAF 函数
    spark.udf.register("myavg", new MyUDAF)
//6.使用自定义UDAF函数
    spark.sql("select myavg(salary) from word").show()
发布了88 篇原创文章 · 获赞 99 · 访问量 21万+

猜你喜欢

转载自blog.csdn.net/qq_43791724/article/details/105468182