Spark SQL Programming: User-Defined Functions

Creating a Spark SQL Program in IDEA

Packaging and running the program in IDEA works the same way as for Spark Core; the Maven dependencies just need one new entry:

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.11</artifactId>
    <version>2.1.1</version>
</dependency>

The program is as follows:

import org.apache.spark.sql.SparkSession

object HelloWorld {

  def main(args: Array[String]): Unit = {
    // Create the SparkSession and set the application name
    val spark = SparkSession
      .builder()
      .appName("Spark SQL basic example")
      .config("spark.some.config.option", "some-value")
      .getOrCreate()

    // For implicit conversions like converting RDDs to DataFrames
    import spark.implicits._

    val df = spark.read.json("data/people.json")

    // Display the content of the DataFrame to stdout
    df.show()
    df.filter($"age" > 21).show()

    // Register the DataFrame as a temporary view and query it with SQL
    df.createOrReplaceTempView("persons")
    spark.sql("SELECT * FROM persons WHERE age > 21").show()

    spark.stop()
  }
}
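
The program reads data/people.json, which is expected to contain one JSON object per line. A minimal sample (the same records that ship with Spark under examples/src/main/resources/people.json, and which appear in the shell session below) looks like this:

{"name":"Michael"}
{"name":"Andy", "age":30}
{"name":"Justin", "age":19}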

User-Defined Functions (UDF)

In the spark-shell, custom functions can be registered through spark.udf:

scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]
scala> df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
scala> spark.udf.register("addName", (x:String)=> "Name:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
scala> df.createOrReplaceTempView("people")
scala> spark.sql("Select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|     Name:Michael|null|
|        Name:Andy|  30|
|      Name:Justin|  19|
+-----------------+----+
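
The same UDF can also be used outside the shell. A minimal sketch (assuming an existing SparkSession named spark and the same people.json file) that registers it for SQL and additionally applies it through the DataFrame API:

import org.apache.spark.sql.functions.udf
import spark.implicits._

// Register for use inside SQL statements
spark.udf.register("addName", (x: String) => "Name:" + x)

// Wrap the same logic as a column function for the DataFrame API
val addName = udf((x: String) => "Name:" + x)

val people = spark.read.json("examples/src/main/resources/people.json")
people.select(addName($"name").as("name_with_prefix"), $"age").show()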

User-Defined Aggregate Functions

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max() and min(). Beyond these, users can define their own aggregate functions.
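
For comparison, the built-in aggregates can be called directly on a DataFrame; a quick sketch (assuming the employees.json dataset used in the examples below):

import org.apache.spark.sql.functions.{avg, max, countDistinct, col}

val employees = spark.read.json("examples/src/main/resources/employees.json")
employees.select(avg(col("salary")), max(col("salary")), countDistinct(col("name"))).show()
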
Weakly typed user-defined aggregate function: extend UserDefinedAggregateFunction to implement a custom aggregate function. The example below computes the average salary.

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

object MyAverage extends UserDefinedAggregateFunction {
  // Data type of the input arguments of the aggregate function
  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
  // Data types of the values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // Data type of the return value
  def dataType: DataType = DoubleType
  // Whether this function always returns the same output for the same input
  def deterministic: Boolean = true
  // Initialize the buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Running total of salaries
    buffer(0) = 0L
    // Number of salaries seen so far
    buffer(1) = 0L
  }
  // Update the buffer with a new input row (within the same executor)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  // Merge buffers coming from different executors
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

// Register the function
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
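
Once registered, the UDAF can also be invoked through the DataFrame API instead of a SQL string; a small sketch, assuming the df defined above:

import org.apache.spark.sql.functions.{callUDF, col}

df.select(callUDF("myAverage", col("salary")).as("average_salary")).show()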

Strongly typed user-defined aggregate function: extend Aggregator to implement a strongly typed custom aggregate function. Again, the example computes the average salary.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession

// Since the aggregation is strongly typed, it operates on case classes
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // A zero value for this aggregation: the salary total and the count both start at 0
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge intermediate results coming from different executors
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the reduction into the final output
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Encoder for the intermediate value type; Encoders.product handles Scala tuples and case classes
  def bufferEncoder: Encoder[Average] = Encoders.product
  // Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

import spark.implicits._

val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

// Convert the function to a `TypedColumn` and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
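
Note that in the Spark version used here (2.x) an Aggregator can only be applied to typed Datasets as shown above. In Spark 3.0 and later it can additionally be registered for SQL via spark.udf.register("myAverage", functions.udaf(MyAverage)), which replaces the now-deprecated UserDefinedAggregateFunction.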

Reposted from blog.csdn.net/u012387141/article/details/105807536