Spark SQL user-defined functions

In the shell, users can define and register their own functions through spark.udf.

1. User-defined function (UDF)

scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+

scala> spark.udf.register("addName", (x:String)=> "Name:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("Select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|     Name:Michael|null|
|        Name:Andy|  30|
|      Name:Justin|  19|
+-----------------+----+
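The same function can also be used without going through SQL. The snippet below is a minimal sketch (not part of the original shell session) that wraps the closure with org.apache.spark.sql.functions.udf and applies it through the DataFrame API, reusing df and its columns from above; the alias name_with_prefix is introduced here only for illustration.

import org.apache.spark.sql.functions.udf

// Wrap an ordinary Scala closure as a column-level UDF
val addNameCol = udf((x: String) => "Name:" + x)

// Apply it with select instead of spark.sql
df.select(addNameCol(df("name")).as("name_with_prefix"), df("age")).show()
// every name comes back prefixed, e.g. Name:Michael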

2. User-defined aggregate functions

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max() and min(); a minimal sketch of the built-in route is shown right below. Users can also define their own aggregate functions. A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction; the example after the sketch uses this approach to compute the average salary.
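As a point of comparison before the custom implementation, here is a minimal sketch (assumed, not from the original post) of the built-in avg() aggregate applied to the same employees.json file used later in this section; the employeesDf name is introduced only for illustration.

import org.apache.spark.sql.functions.avg

// Built-in aggregate: average salary over the whole DataFrame
val employeesDf = spark.read.json("examples/src/main/resources/employees.json")
employeesDf.select(avg("salary")).show()
// with the sample data shown below, the result is 3750.0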

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

object MyAverage extends UserDefinedAggregateFunction {
  // Data type of the input arguments of this aggregate function
  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
  // Data types of the values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // Data type of the returned value
  def dataType: DataType = DoubleType
  // Whether this function always returns the same output for the same input
  def deterministic: Boolean = true
  // Initialize the buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // holds the running total of salaries
    buffer(0) = 0L
    // holds the number of salaries seen so far
    buffer(1) = 0L
  }
  // Merge new input rows into the buffer within the same executor
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  // Merge buffers coming from different executors
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

// Register the function
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
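Once registered, the weakly typed aggregate is not limited to SQL strings. As a hedged sketch (assuming the same df as above), it can also be invoked by name from the DataFrame API through functions.callUDF:

import org.apache.spark.sql.functions.callUDF

// Call the registered aggregate function by name from the DataFrame API
df.select(callUDF("myAverage", df("salary")).as("average_salary")).show()
// should again produce 3750.0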

Strongly typed user-defined aggregate function: implemented by extending Aggregator, again computing the average salary.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession

// Since this version is strongly typed, the input and buffer are case classes
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // A zero value for this aggregation: total salary and count, both initialized to 0
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge intermediate results coming from different executors
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Compute the final output
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Encoder for the intermediate value type; Encoders.product handles Scala tuples and case classes
  def bufferEncoder: Encoder[Average] = Encoders.product
  // Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

import spark.implicits._
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

// Convert the function to a `TypedColumn` and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+


Source: blog.csdn.net/weixin_43520450/article/details/108585667