Custom functions
Types
- UDF: one row in, one value out
- UDAF: many rows in, one value out (both illustrated with built-ins in the sketch below)
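A minimal sketch of the distinction using Spark's own built-in functions (assuming a DataFrame df with name, age, and dept columns, and spark.implicits._ in scope):
import org.apache.spark.sql.functions.{avg, upper}
df.select(upper($"name")).show()             // UDF style: one output value per input row
df.groupBy($"dept").agg(avg($"age")).show()  // UDAF style: many rows collapse to one value per group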
UDF
Process
- 1. Define a custom UDF function/class (a class must be serializable)
- 2. Register it: spark.udf.register("name", custom function / custom method _)
- 3. Invoke it in a query
Custom UDF function and invocation
import org.apache.spark.sql.SparkSession
import org.junit.Test
/**
 * @ClassName: MyUDFdemo
 * @Description: pad employee ids shorter than 8 digits with leading zeros
 * @Author: kele
 * @Date: 2021/2/1 20:56
 **/
/**
 * 1. Define a custom UDF function/class (a class must be serializable)
 * 2. Register it: spark.udf.register("name", custom function / custom method _)
 * 3. Invoke it in a query
 */
class MyUDFdemo extends Serializable{
@Test
def emp_info={
val spark = SparkSession.builder().master("local[4]").appName("UDFdemo").getOrCreate()
import spark.implicits._ // implicit conversions for RDD -> DataFrame (toDF)
val rdd1 = spark.sparkContext.parallelize(List(
("00123","zhangsan"),
("256","lisi"),
("0135","wangwu"),
("000368","qianqi"),
("00378","zhaoliu")
))
val df = rdd1.toDF("id","name")
/**
 * Approach 1: call the custom function through SQL
 */
// df.createOrReplaceTempView("user")
// spark.udf.register("fullId",fullUserId)
// spark.sql("""select fullId(id) from user """).show()
/**
 * Registering a method of a custom class (the class must be serializable)
 */
df.createOrReplaceTempView("user")
spark.udf.register("fullId2",fullUserIdclass _)
spark.sql("""select fullId2(id) from user """).show()
/**
 * Approach 2: query via selectExpr
 */
df.selectExpr("fullId2(id) id").show()
}
// custom UDF as a function value
val fullUserId = (id : String)=>{
s"${"0" *(8-id.length)}${id}"
}
// custom UDF as a method (its enclosing class must be serializable)
def fullUserIdclass(id:String) ={
s"${"0" *(8-id.length)}${id}"
}
}
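Besides SQL and selectExpr, the Column DSL offers a third call style via org.apache.spark.sql.functions.udf; a minimal sketch reusing df from the example above (the name fullIdUdf is ours, not from the original, and spark.implicits._ is assumed in scope):
import org.apache.spark.sql.functions.udf
val fullIdUdf = udf((id: String) => "0" * (8 - id.length) + id)
df.select(fullIdUdf($"id").as("id")).show()
// Expected padding: "00123" -> "00000123", "256" -> "00000256", ...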
UDAF weak-type implementation
Overall process (note: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of the strongly typed Aggregator shown later)
- 1. Extend UserDefinedAggregateFunction (no generics)
- 2. Override its methods:
  - 1. inputSchema: the type of the column(s) to aggregate
  - 2. bufferSchema: the types of the intermediate variables
  - 3. dataType: the return type of the function
  - 4. deterministic: the stability setting (same input, same output)
  - 5. initialize: the initial values of the intermediate variables
  - 6. update: the computation within a single task
  - 7. merge: the computation across partitions
  - 8. evaluate: the return value of the function
- 3. Register it with spark.udf.register and bind a name to it
Custom weak-typed UDAF
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
/**
 * @ClassName: MyUDAF
 * @Description: weak-typed UDAF that computes the average age
 * @Author: kele
 * @Date: 2021/2/1 16:03
 **/
class MyUDAF extends UserDefinedAggregateFunction{
/**
 * Specify the type of the data to aggregate
 * @return a StructType describing the input column(s)
 */
override def inputSchema: StructType = new StructType().add("age",IntegerType)
/**
 * We are computing an average, which needs a sum and a count, hence two intermediate variables.
 * Specify the types of the intermediate variables; rows arrive one at a time.
 * @return
 */
override def bufferSchema: StructType = new StructType().add("sum",IntegerType)
.add("num",IntegerType)
/**
 * The return type of the function
 * @return
 */
override def dataType: DataType = DoubleType
/**
 * Stability: whether the same input data always returns the same value
 * @return
 */
override def deterministic: Boolean = true
/**
 * Initialize the buffer values
 * @param buffer
 */
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0)
buffer.update(1,0)
}
/**
 * The computation within a single task:
 * sum keeps accumulating age, count increments by 1
 * @param buffer
 * @param input
 */
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0,buffer.getAs[Int](0)+input.getAs[Int](0))
buffer.update(1,buffer.getAs[Int](1)+1)
}
/**
 * The computation across partitions
 * @param buffer1
 * @param buffer2
 */
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,buffer1.getAs[Int](0)+buffer2.getAs[Int](0))
buffer1.update(1,buffer1.getAs[Int](1)+buffer2.getAs[Int](1))
}
/**
 * The return value
 * @param buffer
 * @return
 */
override def evaluate(buffer: Row): Any = buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
}
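To make the buffer life cycle concrete, here is a hedged trace for the 开发部 group from the sample data below, assuming (purely for illustration) that ages 20 and 26 land in one partition and 40 in another:
// initialize            -> buffer = (sum = 0,  num = 0)
// update(age = 20)      -> buffer = (20, 1)
// update(age = 26)      -> buffer = (46, 2)
// merge((46,2), (40,1)) -> buffer = (86, 3)   // other partition contributed age 40
// evaluate              -> 86.0 / 3 ≈ 28.67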
Call process
/**
 * Invoke the weak-typed UDAF
 */
@Test
def avg_Age={
val spark = SparkSession.builder()
.master("local[4]")
.appName("avg_age")
.getOrCreate()
val rdd = spark.sparkContext.parallelize(List(
("zhangsan",20,"开发部"),
("wanwu",25,"产品部"),
("aa",26,"开发部"),
("lisi",40,"开发部"),
("bb",30,"产品部"),
("cc",28,"产品部")
))
import spark.implicits._
val df = rdd.toDF("name","age","dept")
df.createOrReplaceTempView("user")
spark.udf.register("myavg",new MyUDAF)
spark.sql(
"""
|select myavg(age) from user group by dept
""".stripMargin).show()
}
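Since the aggregation is deterministic, the expected per-department averages from the sample data can be checked by hand:
// 开发部: (20 + 26 + 40) / 3 ≈ 28.67
// 产品部: (25 + 30 + 28) / 3 ≈ 27.67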
UDAF strong-typing process
- 1. A custom class extends Aggregator[type of the column to aggregate, type of the intermediate variable, type of the output result]
- 2. Override its methods:
  - 1. zero: initialize the intermediate variable
  - 2. reduce: the aggregation within each task
  - 3. merge: the computation across partitions
  - 4. finish: compute and return the final result
  - 5. bufferEncoder/outputEncoder: encode the type of the intermediate variable (personally, I think this is to guarantee the transmission of the intermediate data);
    note that the parent class of a case class is Product
- 3. Register it: spark.udf.register(function name, udaf(custom UDAF object))
  - import org.apache.spark.sql.functions._ // required, otherwise udaf is not available
Custom strong-typed UDAF
package com.atguigu.day05
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
 * @ClassName: MyUDAFStrong
 * @Description: strong-typed custom class; Aggregator takes generics [input type, intermediate variable type, output type]
 * @Author: kele
 * @Date: 2021/2/1 16:05
 **/
/**
 * If several intermediate variables are needed, consider using a case class.
 */
case class InterVari(var sum:Int,var count:Int)
class MyUDAFStrong extends Aggregator[Int,InterVari,Double]{
/**
 * Initialize the intermediate variable
 * @return
 */
override def zero: InterVari = InterVari(0,0)
/**
 * The aggregation within each task
 * @param b
 * @param a
 * @return
 */
override def reduce(b: InterVari, a: Int): InterVari = {
b.sum = b.sum+a
b.count = b.count+1
b
}
/**
 * The computation across partitions
 * @param b1
 * @param b2
 * @return
 */
override def merge(b1: InterVari, b2: InterVari): InterVari = {
b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count
b1
}
/**
 * Return the final result
 * @param reduction
 * @return
 */
override def finish(reduction: InterVari): Double = reduction.sum.toDouble/reduction.count
/**
 * Encode the type of the intermediate variable; personally, I think this guarantees the transmission of the intermediate data
 * @return note that the parent class of a case class is Product
 */
override def bufferEncoder: Encoder[InterVari] = Encoders.product
/**
 * Encode the type of the result value; same reasoning as above
 * @return
 */
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
Call process
/**
 * Invoke the strong-typed UDAF
 */
@Test
def avg_Age_strong={
val spark = SparkSession.builder()
.master("local[4]")
.appName("avg_age")
.getOrCreate()
val rdd = spark.sparkContext.parallelize(List(
("zhangsan",20,"开发部"),
("wanwu",25,"产品部"),
("aa",26,"开发部"),
("lisi",40,"开发部"),
("bb",30,"产品部"),
("cc",28,"产品部")
))
import spark.implicits._
import org.apache.spark.sql.functions.udaf // required for udaf(...)
val df = rdd.toDF("name","age","dept")
df.createOrReplaceTempView("user")
spark.udf.register("myavg", udaf(new MyUDAFStrong))
spark.sql(
"""
|select myavg(age) from user group by dept
""".stripMargin).show()
}
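Beyond SQL registration, an Aggregator can also be applied directly to a typed Dataset via its toColumn method; a minimal sketch under that assumption, which computes one overall average rather than per-department values:
import spark.implicits._
val ages = df.select($"age").as[Int]            // typed Dataset[Int]
ages.select(new MyUDAFStrong().toColumn).show() // single overall average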