Spark-Spark SQLカスタム関数UDF UDAF UDTF

 

 カスタム関数分類

ハイブのカスタム関数と同様に、sparkはカスタム関数を使用して新しい関数を実装することもできます。

Sparkには3種類のカスタム関数があります

1.UDF(ユーザー定義関数)

              行を入力し、行を出力します

2.UDAF(ユーザー定義集計関数)

              複数行入力、1行出力

3.UDTF(ユーザー定義のテーブル生成関数)

              1行入力し、複数行を出力

 

カスタムUDF

●需要

udf.txtのデータ形式は次のとおりです。

こんにちは

ABC

調査

小さい

 

カスタムUDF関数を使用して、データの各行を大文字に変換します

t_wordから値、smallToBig(値)を選択します

 

●コードのデモ

package cn.itcast.sql

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, SparkSession}


object UDFDemo {
  def main(args: Array[String]): Unit = {
    //1.创建SparkSession
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkSQL").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")
    //2.读取文件
    val fileDS: Dataset[String] = spark.read.textFile("D:\\data\\udf.txt")
    fileDS.show()
    /*
    +----------+
    |     value|
    +----------+
    |helloworld|
    |       abc|
    |     study|
    | smallWORD|
    +----------+
     */
   /*
    将每一行数据转换成大写
    select value,smallToBig(value) from t_word
    */
    //注册一个函数名称为smallToBig,功能是传入一个String,返回一个大写的String
    spark.udf.register("smallToBig",(str:String) => str.toUpperCase())
    fileDS.createOrReplaceTempView("t_word")
    //使用我们自己定义的函数
    spark.sql("select value,smallToBig(value) from t_word").show()
    /*
    +----------+---------------------+
    |     value|UDF:smallToBig(value)|
    +----------+---------------------+
    |helloworld|           HELLOWORLD|
    |       abc|                  ABC|
    |     study|                STUDY|
    | smallWORD|            SMALLWORD|
    +----------+---------------------+
     */
    sc.stop()
    spark.stop()
  }
}

 

カスタムUDAF

●需要

udaf.jsonのデータ内容は以下の通りです

{"名前": "マイケル"、 "給与":3000}

{"名前": "アンディ"、 "給与":4500}

{"名前": "ジャスティン"、 "給与":3500}

{"名前": "ベルタ"、 "給与":4000}

平均賃金

 

●UserDefinedAggregateFunctionメソッドの継承手順

inputSchema:入力データのタイプ

bufferSchema:中間結果を生成するデータ型

dataType:返される最終的な結果タイプ

確定的:一貫性を確保するために、通常はtrueを使用します

initialize:初期値を指定します

update:データが操作に参加するたびに中間結果を更新します(更新は各パーティションでの操作に相当します)

マージ:グローバル集計(各パーティションの結果を集計)

評価:最終結果を計算する

 

●コードのデモ

package cn.itcast.sql

import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


object UDAFDemo {
  def main(args: Array[String]): Unit = {
    //1.获取sparkSession
    val spark: SparkSession = SparkSession.builder().appName("SparkSQL").master("local[*]").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")
    //2.读取文件
    val employeeDF: DataFrame = spark.read.json("D:\\data\\udaf.json")
    //3.创建临时表
    employeeDF.createOrReplaceTempView("t_employee")
    //4.注册UDAF函数
    spark.udf.register("myavg",new MyUDAF)
    //5.使用自定义UDAF函数
    spark.sql("select myavg(salary) from t_employee").show()
    //6.使用内置的avg函数
    spark.sql("select avg(salary) from t_employee").show()
  }
}
class MyUDAF extends UserDefinedAggregateFunction{
  //输入的数据类型的schema
  override def inputSchema: StructType = {
     StructType(StructField("input",LongType)::Nil)
  }
  //缓冲区数据类型schema,就是转换之后的数据的schema
  override def bufferSchema: StructType = {
    StructType(StructField("sum",LongType)::StructField("total",LongType)::Nil)
  }
  //返回值的数据类型
  override def dataType: DataType = {
    DoubleType
  }
  //确定是否相同的输入会有相同的输出
  override def deterministic: Boolean = {
    true
  }
  //初始化内部数据结构
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  //更新数据内部结构,区内计算
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //所有的金额相加
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    //一共有多少条数据
    buffer(1) = buffer.getLong(1) + 1
  }
  //来自不同分区的数据进行合并,全局合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) =buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  //计算输出数据值
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

 

続く!適切なUDTFケースを探しています!

 

元の記事113件を公開 賞賛200件 閲覧回数170,000回

おすすめ

転載: blog.csdn.net/weixin_44036154/article/details/105460958