A Demo of a UDAF Function in Hive

Scenario: a demo program that, for each group, computes the count, the sum, and the detail list (dtl) of the values, and outputs the result as a single concatenated string in the form count-sum-dtl. The main point of the demo is using a struct as the intermediate (partial) result.
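To make the output concrete, here is a made-up example (column name and numbers are purely illustrative): if a group's input values are 1, 2 and 3, the function returns

countSumDtl(col)  =>  '3-6.0-1,2,3'    -- count, then sum, then the comma-separated detail values

The order of the detail values is not guaranteed; it depends on the order in which rows and partial results are processed.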
Code:
The UDAF function

package com.jd.pop.qc.udf;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import java.util.ArrayList;
/**
 * Created by songhongwei on 2017-04-06.
 * A demo: for each group, compute count, sum and dtl, and output them as one concatenated string (count-sum-dtl).
 */
public class CountSumDtl extends AbstractGenericUDAFResolver {
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        return new GenericUDAFCountSumDtlEvaluator();
    }
    public static class GenericUDAFCountSumDtlEvaluator extends GenericUDAFEvaluator{
        final static String split = "-";
        final static String comma = ",";
        // PARTIAL1 / COMPLETE: object inspector for the raw input column
        private transient PrimitiveObjectInspector inputOI;
        // PARTIAL2 / FINAL: object inspectors for the partial-result struct
        private transient StructObjectInspector soi;
        private transient StructField countField;
        private transient StructField sumField;
        private transient StructField contentField;
        private transient LongObjectInspector countFieldOI;
        private transient DoubleObjectInspector sumFieldOI;
        private transient StringObjectInspector contentFieldOI;
        // reusable output struct {count, sum, content} returned by terminatePartial()
        private Object[] partialResult;
        public static class PartialResultAgg implements AggregationBuffer {
            long count;
            double sum;
            String content;
        }
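        // Background on evaluator modes (standard GenericUDAFEvaluator behaviour), which decide
        // what init() receives and which methods Hive calls afterwards:
        //   PARTIAL1: original rows   -> partial aggregation  (iterate()  + terminatePartial())
        //   PARTIAL2: partial results -> partial aggregation  (merge()    + terminatePartial())
        //   FINAL:    partial results -> full result          (merge()    + terminate())
        //   COMPLETE: original rows   -> full result          (iterate()  + terminate())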
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            super.init(mode, parameters);
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
                inputOI = (PrimitiveObjectInspector)parameters[0];
            }else{
                soi =  (StructObjectInspector)parameters[0];
                countField = soi.getStructFieldRef("count");
                sumField = soi.getStructFieldRef("sum");
                contentField = soi.getStructFieldRef("content");
                countFieldOI = (LongObjectInspector)countField.getFieldObjectInspector();
                sumFieldOI = (DoubleObjectInspector)sumField.getFieldObjectInspector();
                contentFieldOI = (StringObjectInspector)contentField.getFieldObjectInspector();
            }
            //init output
            if(mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2){
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("count");
                fname.add("sum");
                fname.add("content");
                partialResult = new Object[]{new LongWritable(0),new DoubleWritable(0),new Text()};
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname,foi);
            } else {
                return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
            }
        }
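        // Hive may reuse a single AggregationBuffer for many groups by calling reset(),
        // so reset() has to clear every field of the buffer.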
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            PartialResultAgg partialResultAgg = new PartialResultAgg();
            reset(partialResultAgg);
            return partialResultAgg;
        }
        @Override
        public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
            PartialResultAgg partialResultAgg = (PartialResultAgg)aggregationBuffer;
            partialResultAgg.count = 0;
            partialResultAgg.sum = 0;
            partialResultAgg.content = "";
        }
        @Override
        public void iterate(AggregationBuffer aggregationBuffer, Object[] objects) throws HiveException {
            // objects holds the arguments of one input row; skip empty or NULL input
            if (ArrayUtils.isEmpty(objects) || objects[0] == null) {
                return;
            }
            PartialResultAgg partialResultAgg = (PartialResultAgg) aggregationBuffer;
            double value = PrimitiveObjectInspectorUtils.getDouble(objects[0], inputOI);
            partialResultAgg.count++;
            partialResultAgg.sum += value;
            // append the value to the detail string, comma-separated
            partialResultAgg.content += (partialResultAgg.content.isEmpty() ? "" : comma) + (long) value;
        }
        @Override
        public Object terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
            PartialResultAgg partialResultAgg = (PartialResultAgg)aggregationBuffer;
            ((LongWritable)partialResult[0]).set(partialResultAgg.count);
            ((DoubleWritable)partialResult[1]).set(partialResultAgg.sum);
            ((Text) partialResult[2]).set(partialResultAgg.content);
            return partialResult;
        }
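        // merge() receives the struct emitted by terminatePartial() on another map/combiner task;
        // its fields are decoded with the ObjectInspectors resolved in the PARTIAL2/FINAL branch of init().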
        @Override
        public void merge(AggregationBuffer aggregationBuffer, Object o) throws HiveException {
            if (o == null) {
                return;
            }
            PartialResultAgg partialResultAgg = (PartialResultAgg) aggregationBuffer;
            long oCount = countFieldOI.get(soi.getStructFieldData(o, countField));
            double oSum = sumFieldOI.get(soi.getStructFieldData(o, sumField));
            String oContent = contentFieldOI.getPrimitiveJavaObject(soi.getStructFieldData(o, contentField));
            partialResultAgg.count += oCount;
            partialResultAgg.sum += oSum;
            // append the other task's detail string, comma-separated
            if (oContent != null && !oContent.isEmpty()) {
                partialResultAgg.content += (partialResultAgg.content.isEmpty() ? "" : comma) + oContent;
            }
        }
        @Override
        public Object terminate(AggregationBuffer aggregationBuffer) throws HiveException {
            PartialResultAgg partialResultAgg = (PartialResultAgg) aggregationBuffer;
            // final output string: count-sum-dtl
            String ret = partialResultAgg.count + split + partialResultAgg.sum + split + partialResultAgg.content;
            return new Text(ret);
        }
    }
}

Usage:

add jar /home/mart_pop/tianhe/qc/qc_shop_qlty_sort/jar/pop-qc-hive-1.0.0.jar;
create temporary function countSumDtl as 'com.jd.pop.qc.udf.CountSumDtl';
select
        item_first_cate_cd,
        min(item_first_cate_name) item_first_cate_name,
        countSumDtl(sku_order_cnt) d_sku_order_cnt
from
        app.app_qc_shop_qlty_sort_topsis_dtl
where
        dt = '2017-03-31' and 
        item_first_cate_cd in (4051,4052)
group by
        item_first_cate_cd
;
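A temporary function only lives for the current session. If the jar has been uploaded to HDFS, Hive 0.13+ can also register it as a permanent function; a minimal sketch, with a hypothetical HDFS path:

create function countSumDtl as 'com.jd.pop.qc.udf.CountSumDtl'
using jar 'hdfs:///user/mart_pop/udf/pop-qc-hive-1.0.0.jar';   -- hypothetical path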

Execution result:
(Screenshot of the query output omitted.)

Finally, here is another example from the web (an analysis of the collect_set source code):
http://www.lai18.com/content/2694127.html?from=cancel

Reposted from blog.csdn.net/otengyue/article/details/69499853