User-defined aggregate function
Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). In addition, users can define their own custom aggregate functions.
Weakly typed user-defined aggregate functions
A weakly typed user-defined aggregate function is implemented by extending UserDefinedAggregateFunction. The following example shows a custom aggregate function that computes the average age.
First, create a class for the custom aggregate function that extends the abstract class UserDefinedAggregateFunction and implements its abstract methods:
package sparksql.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class MyAvg extends UserDefinedAggregateFunction {
  // Structure of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Structure of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Data type of the output
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e. the same input always gives the same output
  override def deterministic: Boolean = true

  /**
   * Initializes the buffer, with the following requirement:
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // How to update the buffer when a new input row arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  // How to merge two buffers:
  // merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /* Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given aggregation buffer. */
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
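To make the roles of initialize, update, merge, and evaluate more concrete, here is a minimal plain-Scala sketch of how a (sum, count) buffer could flow through them. It does not use Spark: `AvgLifecycleSketch`, `SumCount`, `average`, and the sample ages are hypothetical names and data chosen for illustration, and the order in which Spark actually calls these methods may differ.

```scala
// Plain-Scala simulation of the aggregation life cycle (no Spark required).
// All names are hypothetical and only mirror the methods of MyAvg.
object AvgLifecycleSketch {
  // Mirrors bufferSchema: (sum, count)
  final case class SumCount(sum: Long, count: Long)

  // Mirrors initialize: the empty buffer
  val initial: SumCount = SumCount(0L, 0L)

  // Mirrors update: fold one input value into the buffer
  def update(buffer: SumCount, age: Long): SumCount =
    SumCount(buffer.sum + age, buffer.count + 1L)

  // Mirrors merge: combine two partial buffers
  def merge(b1: SumCount, b2: SumCount): SumCount =
    SumCount(b1.sum + b2.sum, b1.count + b2.count)

  // Mirrors evaluate: produce the final result
  def evaluate(buffer: SumCount): Double =
    buffer.sum.toDouble / buffer.count

  // Aggregate each partition separately, then merge the partial buffers
  def average(partitions: Seq[Seq[Long]]): Double = {
    val partials = partitions.map(_.foldLeft(initial)(update))
    evaluate(partials.foldLeft(initial)(merge))
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical ages split across two partitions
    val partitions = Seq(Seq(20L, 21L), Seq(22L))
    println(average(partitions)) // 21.0
  }
}
```

Because merging the initial buffer with itself yields the initial buffer again, the sketch also satisfies the contract quoted in the comment on initialize.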
Then create an instance of the class, register it, and use it in a SQL statement to test it:
package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Demo1 {
  def main(args: Array[String]): Unit = {
    // Create a SparkConf and set the app name
    val conf = new SparkConf().setAppName("sparlsql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    userDF.createOrReplaceTempView("user")
    // Create an instance of the aggregate function
    val myavg = new MyAvg()
    // Register the aggregate function
    spark.udf.register("udfavg", myavg)
    // Use the aggregate function
    spark.sql("select udfavg(age) from user").show
  }
}
The printed result is as follows:
+----------+
|myavg(age)|
+----------+
|      21.0|
+----------+
Strongly typed user-defined aggregate functions
A strongly typed custom aggregate function is implemented by extending Aggregator. The example again computes the average age.
First, create a custom aggregate class that extends the abstract class Aggregator and implements its abstract methods:
package sparksql.udf

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class UserBean(name: String, age: Long)
case class Buffer(var sum: Long, var count: Long)

class MyAvg2 extends Aggregator[UserBean, Buffer, Double] {
  // Defines the initial value of the buffer.
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  override def zero: Buffer = Buffer(0, 0)

  /*
   Defines the calculation rule applied when a data item arrives.
   Combine two values to produce a new value. For performance, the function may modify `b` and
   return it instead of constructing new object for b.
   */
  override def reduce(b: Buffer, a: UserBean): Buffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1L
    b
  }

  // Merges two buffers
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // The final result
  override def finish(reduction: Buffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[Buffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
The generic type parameters of Aggregator have the following meanings:
IN: input data type
BUF: buffer data type
OUT: output data type
/**
 * @tparam IN The input type for the aggregation.
 * @tparam BUF The type of the intermediate value of the reduction.
 * @tparam OUT The type of the final output result.
 * @since 1.6.0
 */
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
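The three type parameters do not have to be related to one another. As an illustration, the sketch below uses a simplified plain-Scala stand-in for the shape of Aggregator, with IN, BUF, and OUT all different. `SimpleAggregator`, `User`, `OldestName`, `run`, and the sample data are hypothetical and are not part of Spark; encoders are left out entirely.

```scala
// Simplified stand-in for the shape of Aggregator (no Spark required).
// SimpleAggregator, User, OldestName, and run are hypothetical names.
object AggregatorTypesSketch {
  trait SimpleAggregator[-IN, BUF, OUT] {
    def zero: BUF
    def reduce(b: BUF, a: IN): BUF
    def merge(b1: BUF, b2: BUF): BUF
    def finish(reduction: BUF): OUT
  }

  final case class User(name: String, age: Long)

  // IN = User, BUF = Option[User] (oldest user seen so far), OUT = String
  object OldestName extends SimpleAggregator[User, Option[User], String] {
    def zero: Option[User] = None
    def reduce(b: Option[User], a: User): Option[User] = merge(b, Some(a))
    def merge(b1: Option[User], b2: Option[User]): Option[User] =
      (b1, b2) match {
        case (Some(x), Some(y)) => if (y.age > x.age) b2 else b1
        case (None, _)          => b2
        case (_, None)          => b1
      }
    def finish(reduction: Option[User]): String =
      reduction.map(_.name).getOrElse("")
  }

  // Generic driver: reduce each partition from zero, merge the partials, then finish
  def run[IN, BUF, OUT](agg: SimpleAggregator[IN, BUF, OUT], partitions: Seq[Seq[IN]]): OUT = {
    val partials = partitions.map(_.foldLeft(agg.zero)(agg.reduce))
    agg.finish(partials.foldLeft(agg.zero)(agg.merge))
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical users split across two partitions
    val partitions = Seq(Seq(User("Ann", 20L), User("Bob", 22L)), Seq(User("Cid", 21L)))
    println(run(OldestName, partitions)) // Bob
  }
}
```

Here `None` plays the role of the zero value: merging it with any buffer returns that buffer unchanged.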
Then create an instance of the custom class, convert it to a TypedColumn, and pass it to the select method:
package sparksql.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, TypedColumn}

object Demo2 {
  def main(args: Array[String]): Unit = {
    // Create a SparkConf and set the app name
    val conf = new SparkConf().setAppName("sparlsql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
    import spark.implicits._
    val userDS: Dataset[UserBean] = userDF.as[UserBean]
    // Create an instance of the MyAvg2 class
    val myavg2 = new MyAvg2()
    // Convert the instance to a TypedColumn
    val udfavg: TypedColumn[UserBean, Double] = myavg2.toColumn.name("myavg")
    // Use it
    userDS.select(udfavg).show
    spark.stop()
  }
}
Pitfall:
When a DataFrame is created with spark.read.json, integer values are automatically inferred as bigint. If such a field is then mapped to an Int, an exception is thrown:
Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot up cast `age` from bigint to int as it may truncate
For example, suppose you declare the case class UserBean like this:
case class UserBean(name:String,age:Int)
and then create the DataFrame and Dataset like this:
val userDF: DataFrame = spark.read.json("E:/idea/spark3/in/user.json")
import spark.implicits._
val userDS: Dataset[UserBean] = userDF.as[UserBean]
The exception shown above is then thrown.
So declare such fields as Long, which corresponds to bigint.
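The phrase "as it may truncate" in the error message refers to the fact that a 64-bit value does not always fit into a 32-bit Int. The following plain-Scala sketch (no Spark involved) illustrates that risk; `TruncationSketch` and `toIntExact` are hypothetical names used only for this illustration.

```scala
// Plain Scala, no Spark required: why narrowing a 64-bit value to Int is unsafe.
// TruncationSketch and toIntExact are hypothetical names.
object TruncationSketch {
  // Returns Some(int) only when the Long fits into an Int without loss
  def toIntExact(value: Long): Option[Int] =
    if (value.isValidInt) Some(value.toInt) else None

  def main(args: Array[String]): Unit = {
    val small = 21L
    val big = Int.MaxValue.toLong + 1L // 2147483648, does not fit into an Int

    println(small.toInt)       // 21
    println(big.toInt)         // -2147483648 (silently wrapped around)
    println(toIntExact(small)) // Some(21)
    println(toIntExact(big))   // None
  }
}
```

Declaring the field as Long avoids the narrowing conversion altogether, which is why the case class with age: Long works.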