UDAF(用户自定义聚合函数)求众数

除了逐行处理数据的udf,还有比较常见的就是聚合多行处理udaf,自定义聚合函数。类比rdd编程就是map和reduce算子的区别。
自定义UDAF,需要extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction,并实现接口中的8个方法。
udaf写起来比较麻烦,我下面列一个之前写的取众数聚合函数,在我们通常在聚合统计的时候可能会受某条脏数据的影响。
举个栗子:
对于一个app日志聚合的时候,有id与ip,原则上一个id有一个ip,但是在多条数据里有一条ip是错误的或者为空的,这时候group能会聚合成两条数据了就,如果使用max,min对ip也进行聚合,那也不太合理,这时候可以进行投票,去类似多数对结果,从而聚合后只有一个设备。
废话少说,上代码:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Description: 自定义聚合函数:众数(取列内频率最高的一条)
  */

class UDAFGetMode extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = {
    StructType(StructField("inputStr", StringType, true) :: Nil)
  }


  override def bufferSchema: StructType = {
    StructType(StructField("bufferMap", MapType(keyType = StringType, valueType = IntegerType), true) :: Nil)
  }

  override def dataType: DataType = StringType

  override def deterministic: Boolean = false

  //初始化map
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.immutable.Map[String, Int]()
  }

  //如果包含这个key则value+1,否则写入key,value=1
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val key = input.getAs[String](0)
    val immap = buffer.getAs[Map[String, Int]](0)
    val bufferMap = scala.collection.mutable.Map[String, Int](immap.toSeq: _*)
    val ret = if (bufferMap.contains(key)) {
      //      val new_value = bufferMap.get(key).get + 1
      val new_value = bufferMap(key) + 1
      bufferMap.put(key, new_value)
      bufferMap
    } else {
      bufferMap.put(key, 1)
      bufferMap
    }
    buffer.update(0, ret)

  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //合并两个map 相同的key的value累加
    val tempMap = (buffer1.getAs[Map[String, Int]](0) /: buffer2.getAs[Map[String, Int]](0)) {
      case (map, (k, v)) => map + (k -> (v + map.getOrElse(k, 0)))
    }
    buffer1.update(0, tempMap)
  }

  override def evaluate(buffer: Row): Any = {
    //返回值最大的key
    var max_value = 0
    var max_key = ""
    buffer.getAs[Map[String, Int]](0).foreach({ x =>
      val key = x._1
      val value = x._2
      if (value > max_value) {
        max_value = value
        max_key = key
      }
    })
    max_key
  }
}

测试类:

object UDAFTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName(this.getClass.getSimpleName).getOrCreate()
    spark.udf.register("get_mode", new UDAFGetMode)
    import spark.implicits._
    val df = Seq(
      (1, "10.10.1.1", "start"),
      (1, "10.10.1.1", "search"),
      (2, "123.123.123.1", "search"),
      (1, "10.10.1.0", "stop"),
      (2, "123.123.123.1", "start")
    ).toDF("id", "ip", "action")

    df.createTempView("tb")
    spark.sql(s"select id,get_mode(ip) as u_ip,count(*) as cnt from tb group by id").show()
  }
}

猜你喜欢

转载自www.cnblogs.com/itboys/p/10626310.html