// test code 2

package org.test.udf

import com.google.gson.{Gson, GsonBuilder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scala.collection.mutable

/**
  * Created by yunzhi.lyz on 17-8-23.
  **/

object FunnelUdf {
 
  private[udf] var gson: Gson = new GsonBuilder().create

  class FunnelCalMerge extends UserDefinedAggregateFunction {
    override def inputSchema = StructType(StructField("item_id", ArrayType(LongType, false), true) ::
      StructField("item_timestamp", ArrayType(LongType, false), true) ::
      StructField("funnelDesc", StringType, true) ::
      StructField("windowLongernal", LongType, true) :: Nil)
    def bufferSchema: StructType = StructType(StructField("countArray", ArrayType(LongType, false), true) :: Nil)
    override def dataType: DataType = StructType(StructField("countArray", ArrayType(LongType, false), true) :: Nil)
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Array.fill[Long](10)(0)
    }
    override def update(buffer: MutableAggregationBuffer, inputrow: Row): Unit = {
      var itemIDArray = inputrow.getAs[mutable.WrappedArray[Long]](0)
      var itemTimestamp = inputrow.getAs[mutable.WrappedArray[Long]](1)
      val funnelRoute = inputrow.getString(2).split(",").map(x => x.toLong)
      val windowLongernal = inputrow.getLong(3)
      var bufferValue = buffer.getAs[mutable.WrappedArray[Long]](0)
      val r = funnelProcess(itemIDArray.toArray[Long], itemTimestamp.toArray[Long], funnelRoute, windowLongernal)
      for (i <- 0 until funnelRoute.length if r(i)) bufferValue(i) = bufferValue(i) + 1
      buffer(0) = bufferValue
    }
    override def merge(buffer: MutableAggregationBuffer, buffer2: Row): Unit = {
      val r1 = buffer.getAs[mutable.WrappedArray[Long]](0)
      val r2 = buffer2.getAs[mutable.WrappedArray[Long]](0)
      for (i <- 0 until 10) r1(i) = r1(i) + r2(i)
      buffer(0) = r1
    }
    override def evaluate(buffer: Row): Any = buffer
  }

  class FunnelCal extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(
      StructField("item_id", LongType, true) ::
        StructField("item_timestamp", LongType, true) :: Nil)
    def bufferSchema: StructType = StructType(StructField("item_id", ArrayType(LongType), true) :: StructField("item_timestamp", ArrayType(LongType), true) :: Nil)
    override def dataType: DataType = StructType(StructField("item_id", ArrayType(LongType), true) :: StructField("item_timestamp", ArrayType(LongType), true) :: Nil)
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Array.empty[LongType]
      buffer(1) = Array.empty[LongType]
    }
    override def update(buffer: MutableAggregationBuffer, inputrow: Row): Unit = {
      buffer(0) = (buffer.getAs[mutable.WrappedArray[Long]](0)).toArray[Long].+:(inputrow.getAs[Long](0))
      buffer(1) = (buffer.getAs[mutable.WrappedArray[Long]](1)).toArray[Long].+:(inputrow.getAs[Long](1))
    }
    override def merge(buffer: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer(0) = (buffer.getAs[mutable.WrappedArray[Long]](0)).toArray[Long].++:((buffer2.getAs[mutable.WrappedArray[Long]](0)).toArray[Long])
      buffer(1) = (buffer.getAs[mutable.WrappedArray[Long]](1)).toArray[Long].++:((buffer2.getAs[mutable.WrappedArray[Long]](1)).toArray[Long])
    }
    override def evaluate(buffer: Row): Any = buffer
  }

  class FunnelCalMerge2 extends UserDefinedAggregateFunction {
    override def inputSchema = StructType(StructField("result", ArrayType(BooleanType), true) :: Nil)
    def bufferSchema: StructType = StructType(StructField("countArray", ArrayType(LongType, false), true) :: Nil)
    override def dataType: DataType = ArrayType(LongType, false)
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Array.fill[Long](10)(0)
    }
    override def update(buffer: MutableAggregationBuffer, inputrow: Row): Unit = {
      var bufferValue = buffer.getAs[mutable.WrappedArray[Long]](0)
      val r = (inputrow.getAs[mutable.WrappedArray[Boolean]](0)).toArray[Boolean]
      for (i <- 0 until r.length if r(i)) bufferValue(i) = bufferValue(i) + 1
      buffer(0) = bufferValue
    }
    override def merge(buffer: MutableAggregationBuffer, buffer2: Row): Unit = {
      val r1 = buffer.getAs[mutable.WrappedArray[Long]](0)
      val r2 = buffer2.getAs[mutable.WrappedArray[Long]](0)
      for (i <- 0 until 10) r1(i) = r1(i) + r2(i)
      buffer(0) = r1
    }
    override def evaluate(buffer: Row): Any = buffer(0)
  }


  class FunnelCal2 extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(
      StructField("item_id", LongType, true) ::
        StructField("item_timestamp", LongType, true) ::
        StructField("funnelDesc", StringType, true) ::
        StructField("windowLongernal", LongType, true) :: Nil)
    def bufferSchema: StructType = StructType(
      StructField("item_id", ArrayType(LongType), true) ::
        StructField("item_timestamp", ArrayType(LongType), true) ::
        StructField("funnelDesc", StringType, true) ::
        StructField("windowLongernal", LongType, true) :: Nil)
    override def dataType: DataType = ArrayType(BooleanType, false)
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Array.empty[LongType]
      buffer(1) = Array.empty[LongType]
      buffer(2) = ""
      buffer(3) = 0l
    }
    override def update(buffer: MutableAggregationBuffer, inputrow: Row): Unit = {
      buffer(0) = (buffer.getAs[mutable.WrappedArray[Long]](0)).toArray[Long].+:(inputrow.getAs[Long](0))
      buffer(1) = (buffer.getAs[mutable.WrappedArray[Long]](1)).toArray[Long].+:(inputrow.getAs[Long](1))
      buffer(2) = inputrow.getString(2)
      buffer(3) = inputrow.getLong(3)
    }
    override def merge(buffer: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer(0) = (buffer.getAs[mutable.WrappedArray[Long]](0)).toArray[Long].++:((buffer2.getAs[mutable.WrappedArray[Long]](0)).toArray[Long])
      buffer(1) = (buffer.getAs[mutable.WrappedArray[Long]](1)).toArray[Long].++:((buffer2.getAs[mutable.WrappedArray[Long]](1)).toArray[Long])
      buffer(2) = buffer2.getString(2)
      buffer(3) = buffer2.getLong(3)
    }
    override def evaluate(buffer: Row): Any = {
      val funnelRoute = buffer.getString(2).split(",").map(x => x.toLong)
      val windowLongernal = buffer.getLong(3)
      val r = funnelProcess((buffer.getAs[mutable.WrappedArray[Long]](0)).toArray[Long],
        (buffer.getAs[mutable.WrappedArray[Long]](1)).toArray[Long], funnelRoute, windowLongernal)
      r
    }
  }

  class JsonInfoGet extends UDF2[String, String, String] {
    def call(jsonInfo: String, keyName: String): String = {
      var value: String = ""
      val map = CalJson.jsonToMap(jsonInfo, gson)
      if (map.containsKey(keyName))
        value = map.get(keyName).toString
      value
    }
  }

  def funnelProcess(dataItem: Array[Long], dataEventTime: Array[Long], rt: Array[Long], wd: Long): Array[Boolean] = {
    val result = Array.fill[Boolean](rt.length)(false)
    val data = dataItem.zip(dataEventTime)
    val sortData = data.sortBy(_._2)
    val indexArrayLength = rt.length - 1
    var startTimeArray = Array.fill[Long](rt.length)(0l)
    val indexMap = rt.map(item => item -> rt.indexOf(item)).toMap
    var notSatisfy = true
    for (itemWithTimeKv <- sortData if notSatisfy) {
      val itemIndex = indexMap(itemWithTimeKv._1)
      // first item
      if (itemIndex == 0) {
        startTimeArray(0) = itemWithTimeKv._2;
        result(0) = true
      } // pre item exists?
      else if (startTimeArray(itemIndex - 1) != 0) {
        // in range
        if ((itemWithTimeKv._2 - startTimeArray(itemIndex - 1)) < wd) {
          startTimeArray(itemIndex) = startTimeArray(itemIndex - 1)
          result(itemIndex) = true
          // out range
        } else
          startTimeArray(itemIndex - 1) = 0
      }
      if (result(indexArrayLength) == true) notSatisfy = false
    }
    result
  }


}

// You may also like

// Reposted from lingzhi007.iteye.com/blog/2391980