Hive UDF array_struct_sort sorts Array<Struct>

1. UDF description

array_struct_sort(array(struct1,struct2,...), string sortField): returns the given array of structs sorted by the given field.
In other words, it sorts an Array<Struct> by the struct field named sortField and returns the sorted array.

2. Code

package com.scb.dss.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde.Constants;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.io.Text;

import java.util.*;

import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.LIST;

/**
 * Hive generic UDF that sorts an array of structs by one of the struct's fields.
 *
 * <p>Usage: {@code array_struct_sort(array<struct<...>>, string sortField)}. The result is a new
 * list holding the same element objects as the input, ordered ascending by {@code sortField};
 * the input array itself is not modified. A NULL array yields NULL.
 */
@Description(name = "array_struct_sort",
        value = "_FUNC_(array(struct1,struct2,...), string sortField) - "
                + "Returns the passed array struct, ordered by the given field",
        extended = "Example:\n"
                + "  > SELECT class, array_struct_sort(collect_list(struct_t), 'age') as struct_array\n" +
                "    FROM (\n" +
                "        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t\n" +
                "        union all \n" +
                "        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t\n" +
                "        union all \n" +
                "        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t\n" +
                "        union all\n" +
                "        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t\n" +
                "    ) as test_data\n" +
                "    group by class;\n")
public class UDFArrayStructSort extends GenericUDF {
    /** Inspectors for all arguments, as received by {@link #initialize}. */
    protected ObjectInspector[] argumentOIs;

    /** Inspector for argument 1, the array to sort. */
    ListObjectInspector loi;
    /** Inspector for the struct elements of argument 1. */
    StructObjectInspector elOi;
    /** Inspector for argument 2, the name of the struct field to sort by. */
    PrimitiveObjectInspector fieldNameOi;

    // The sort field is evaluated per row and may differ between rows, so comparators are
    // cached by field name instead of being built once in initialize().
    Map<String, Comparator<Object>> comparatorCache = new HashMap<String, Comparator<Object>>();

    @Override
    public ObjectInspector initialize(ObjectInspector[] ois) throws UDFArgumentException {
        argumentOIs = ois;

        // Cached comparators resolved their field against the previous element inspector,
        // so they must not survive a re-initialization.
        comparatorCache.clear();

        return checkAndReadObjectInspectors(ois);
    }

    /**
     * Validates the argument inspectors and stores those needed by {@link #evaluate}.
     *
     * @param ois argument inspectors: an array of structs, then a primitive (string) field name
     * @return inspector describing the lists returned by {@link #evaluate}
     * @throws UDFArgumentException if the number of arguments is not 2
     * @throws UDFArgumentTypeException if an argument has an unsupported type
     */
    protected ListObjectInspector checkAndReadObjectInspectors(ObjectInspector[] ois)
            throws UDFArgumentTypeException, UDFArgumentException {
        // exactly two arguments: the array of structs and the name of the field to sort by
        if (ois.length != 2) {
            throw new UDFArgumentException("2 arguments needed, found " + ois.length);
        }

        // first argument must be a list/array ...
        if (!ois[0].getCategory().equals(LIST)) {
            throw new UDFArgumentTypeException(0, "Argument 1"
                    + " of function " + this.getClass().getCanonicalName() + " must be " + Constants.LIST_TYPE_NAME
                    + ", but " + ois[0].getTypeName()
                    + " was found.");
        }
        loi = (ListObjectInspector) ois[0];

        // ... whose elements are structs
        ObjectInspector elementOi = loi.getListElementObjectInspector();
        if (elementOi.getCategory() != ObjectInspector.Category.STRUCT) {
            throw new UDFArgumentTypeException(0, "Argument 1"
                    + " of function " + this.getClass().getCanonicalName() + " must be an array of structs " +
                    " but is an array of " + elementOi.getCategory().name());
        }
        elOi = (StructObjectInspector) elementOi;

        // Second argument: the field name. It is read through its own primitive inspector
        // instead of being cast to a fixed Java type, so any string representation that
        // inspector understands is accepted.
        if (ois[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "Argument 2"
                    + " of function " + this.getClass().getCanonicalName() + " must be " + Constants.STRING_TYPE_NAME
                    + ", but " + ois[1].getTypeName()
                    + " was found.");
        }
        fieldNameOi = (PrimitiveObjectInspector) ois[1];

        // evaluate() returns a java.util.List of the input's element objects. The input list
        // inspector may expect a serde-specific list representation, so describe the result
        // with a standard list inspector over the original element inspector instead.
        return ObjectInspectorFactory.getStandardListObjectInspector(elOi);
    }

    /**
     * Returns the cached comparator for {@code field}, creating it on first use.
     *
     * @param field name of the struct field to sort by
     * @return comparator ordering struct elements ascending by that field
     */
    Comparator<Object> getComparator(Text field) {
        return getComparator(field.toString());
    }

    /**
     * String overload of {@link #getComparator(Text)}.
     *
     * @param fieldName name of the struct field to sort by
     * @return comparator ordering struct elements ascending by that field
     */
    Comparator<Object> getComparator(String fieldName) {
        Comparator<Object> comparator = comparatorCache.get(fieldName);
        if (comparator == null) {
            comparator = new StructFieldComparator(fieldName);
            comparatorCache.put(fieldName, comparator);
        }
        return comparator;
    }

    /**
     * Sorts a shallow copy of the input array by the requested field.
     *
     * @param dos deferred arguments: the array and the sort-field name
     * @return a new sorted list, or {@code null} if the array is NULL
     * @throws HiveException if the wrong number of arguments is supplied or the field name is NULL
     */
    @Override
    public Object evaluate(DeferredObject[] dos) throws HiveException {
        if (dos == null || dos.length != 2) {
            throw new HiveException("received " + (dos == null ? "null" : Integer.toString(dos.length))
                    + " elements instead of 2");
        }

        // NULL array in, NULL out
        Object array = dos[0].get();
        List<?> elements = array == null ? null : loi.getList(array);
        if (elements == null) {
            return null;
        }

        Object fieldName = fieldNameOi.getPrimitiveJavaObject(dos[1].get());
        if (fieldName == null) {
            throw new HiveException("sort field name must not be null");
        }

        // Sort a shallow copy: the input list may be referenced elsewhere in the query.
        List<Object> sorted = new ArrayList<Object>(elements);
        Collections.sort(sorted, getComparator(fieldName.toString()));
        return sorted;
    }

    @Override
    public String getDisplayString(String[] children) {
        if (children == null) {
            return null;
        }
        StringBuilder sb = new StringBuilder(this.getClass().getCanonicalName()).append('(');
        for (int i = 0; i < children.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(children[i]);
        }
        return sb.append(')').toString();
    }

    /**
     * Orders struct elements (as read by {@code elOi}) by the value of one struct field,
     * comparing field values with Hive's {@link ObjectInspectorUtils#compare}.
     */
    public class StructFieldComparator implements Comparator<Object> {
        StructField field;

        /**
         * @param fieldName name of the struct field to compare; resolved once against elOi
         */
        public StructFieldComparator(String fieldName) {
            field = elOi.getStructFieldRef(fieldName);
        }

        @Override
        public int compare(Object o1, Object o2) {
            // No explicit NULL handling here: NULL structs / field values are passed on to
            // getStructFieldData and ObjectInspectorUtils.compare as-is.
            ObjectInspector fieldOi = field.getFieldObjectInspector();
            Object f1 = elOi.getStructFieldData(o1, field);
            Object f2 = elOi.getStructFieldData(o2, field);
            return ObjectInspectorUtils.compare(f1, fieldOi, f2, fieldOi);
        }
    }

}

3. Test

The test data is as follows:

class struct
1 {"name":"N003","age":"20"}
2 {"name":"N001","age":"18"}
1 {"name":"N002","age":"19"}
2 {"name":"N000","age":"17"}

Test code:

-- Per class, collect the structs into an array and sort it by the 'age' field (ascending).
-- Note: 'age' is a string literal here, so the ordering is lexicographic, not numeric.
SELECT class, array_struct_sort(collect_list(struct_t), 'age') as struct_array
    FROM (
        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
        union all 
        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
        union all 
        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
        union all
        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t
    ) as test_data
    group by class;

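The test results are as follows (the original screenshot is not included): each class's array is sorted by age in ascending order, e.g. class 1 yields [{"name":"N002","age":"19"},{"name":"N003","age":"20"}] and class 2 yields [{"name":"N000","age":"17"},{"name":"N001","age":"18"}].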

Combined with the Hive UDAF collect_map from the previous article, we can aggregate the sorted arrays into a MAP<STRING, ARRAY<STRUCT<...>>>:

-- Outer query: collect_map (the UDAF from the previous article, not defined here) turns each
-- class's sorted array into a map entry keyed by class.
SELECT class, collect_map(class, struct_array) as res
FROM (
    -- Inner query: per class, sort the collected structs by 'age' (string, so lexicographic).
    SELECT class, array_struct_sort(collect_list(struct_t), 'age') as struct_array
    FROM (
        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
        union all 
        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
        union all 
        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
        union all
        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t
    ) as test_data
    group by class
) as tmp
group by class
;

 

Compared with the results in the previous article, the structs in each array are now sorted by age.

The res column has type MAP<STRING, ARRAY<STRUCT<name:STRING, age:STRING>>>. To get the name of the youngest student in a class (the first element of the ascending-sorted array), use:

-- Name of the first element (the youngest student, given the ascending sort) of class '1'.
-- Each outer row comes from a single class group, so presumably only that row's map has the
-- key '1' (assuming collect_map keys by its first argument); other rows yield NULL.
select res['1'][0].name
from (
SELECT class, collect_map(class, struct_array) as res
FROM (
    SELECT class, array_struct_sort(collect_list(struct_t), 'age') as struct_array
    FROM (
        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
        union all 
        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
        union all 
        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
        union all
        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t
    ) as test_data
    group by class
) as tmp
group by class
) as t
;

For rows whose map has no key '1' (i.e. the other classes), the expression returns NULL; add a WHERE condition to filter those NULLs out.

4. Improvement

The code above only sorts in ascending order. To support descending order, we wrap the comparator with Collections.reverseOrder() and add a boolean third argument (true = ascending, false = descending). The complete code is as follows:

package com.scb.dss.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde.Constants;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.io.BooleanWritable;
import org.apache.hadoop.io.Text;

import java.util.*;

import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.LIST;

/**
 * Hive generic UDF that sorts an array of structs by one of the struct's fields, in ascending
 * or descending order.
 *
 * <p>Usage: {@code array_struct_sort(array<struct<...>>, string sortField[, boolean asc])}.
 * The result is a new list holding the same element objects as the input, ordered by
 * {@code sortField}; the input array itself is not modified. {@code asc} defaults to
 * {@code true} when omitted or NULL. A NULL array yields NULL.
 */
@Description(name = "array_struct_sort",
        value = "_FUNC_(array(struct1,struct2,...), string sortField[, bool asc]) - "
                + "Returns the passed array struct, ordered by the given field. The default sorting method is ascending",
        extended = "Example:\n"
                + "  > SELECT class, array_struct_sort(collect_list(struct_t), 'age', true) as struct_array\n" +
                "    FROM (\n" +
                "        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t\n" +
                "        union all \n" +
                "        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t\n" +
                "        union all \n" +
                "        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t\n" +
                "        union all\n" +
                "        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t\n" +
                "    ) as test_data\n" +
                "    group by class;\n")
public class UDFArrayStructSort extends GenericUDF {
    /** Inspectors for all arguments, as received by {@link #initialize}. */
    protected ObjectInspector[] argumentOIs;

    /** Inspector for argument 1, the array to sort. */
    ListObjectInspector loi;
    /** Inspector for the struct elements of argument 1. */
    StructObjectInspector elOi;
    /** Inspector for argument 2, the name of the struct field to sort by. */
    PrimitiveObjectInspector fieldNameOi;
    /** Inspector for the optional argument 3 (ascending flag); null when it was omitted. */
    PrimitiveObjectInspector ascOi;

    // The sort field is evaluated per row and may differ between rows, so comparators are
    // cached by field name instead of being built once in initialize().
    Map<String, Comparator<Object>> comparatorCache = new HashMap<String, Comparator<Object>>();

    @Override
    public ObjectInspector initialize(ObjectInspector[] ois) throws UDFArgumentException {
        argumentOIs = ois;

        // Cached comparators resolved their field against the previous element inspector,
        // so they must not survive a re-initialization.
        comparatorCache.clear();

        return checkAndReadObjectInspectors(ois);
    }

    /**
     * Validates the argument inspectors and stores those needed by {@link #evaluate}.
     *
     * @param ois argument inspectors: an array of structs, a primitive (string) field name and
     *            an optional boolean ascending flag
     * @return inspector describing the lists returned by {@link #evaluate}
     * @throws UDFArgumentException if the number of arguments is not 2 or 3
     * @throws UDFArgumentTypeException if an argument has an unsupported type
     */
    protected ListObjectInspector checkAndReadObjectInspectors(ObjectInspector[] ois)
            throws UDFArgumentTypeException, UDFArgumentException {
        // two or three arguments: the array of structs, the name of the field to sort by and,
        // optionally, whether to sort ascending
        if (ois.length < 2 || ois.length > 3) {
            throw new UDFArgumentException("2 or 3 arguments needed, found " + ois.length);
        }

        // first argument must be a list/array ...
        if (!ois[0].getCategory().equals(LIST)) {
            throw new UDFArgumentTypeException(0, "Argument 1"
                    + " of function " + this.getClass().getCanonicalName() + " must be " + Constants.LIST_TYPE_NAME
                    + ", but " + ois[0].getTypeName()
                    + " was found.");
        }
        loi = (ListObjectInspector) ois[0];

        // ... whose elements are structs
        ObjectInspector elementOi = loi.getListElementObjectInspector();
        if (elementOi.getCategory() != ObjectInspector.Category.STRUCT) {
            throw new UDFArgumentTypeException(0, "Argument 1"
                    + " of function " + this.getClass().getCanonicalName() + " must be an array of structs " +
                    " but is an array of " + elementOi.getCategory().name());
        }
        elOi = (StructObjectInspector) elementOi;

        // Second argument: the field name. It is read through its own primitive inspector
        // instead of being cast to a fixed Java type, so any string representation that
        // inspector understands is accepted.
        if (ois[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "Argument 2"
                    + " of function " + this.getClass().getCanonicalName() + " must be " + Constants.STRING_TYPE_NAME
                    + ", but " + ois[1].getTypeName()
                    + " was found.");
        }
        fieldNameOi = (PrimitiveObjectInspector) ois[1];

        // Optional third argument: a boolean, true = ascending, false = descending.
        ascOi = null;
        if (ois.length == 3) {
            if (ois[2].getCategory() != ObjectInspector.Category.PRIMITIVE
                    || ((PrimitiveObjectInspector) ois[2]).getPrimitiveCategory()
                            != PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN) {
                throw new UDFArgumentTypeException(2, "Argument 3"
                        + " of function " + this.getClass().getCanonicalName() + " must be " + Constants.BOOLEAN_TYPE_NAME
                        + ", but " + ois[2].getTypeName()
                        + " was found.");
            }
            ascOi = (PrimitiveObjectInspector) ois[2];
        }

        // evaluate() returns a java.util.List of the input's element objects. The input list
        // inspector may expect a serde-specific list representation, so describe the result
        // with a standard list inspector over the original element inspector instead.
        return ObjectInspectorFactory.getStandardListObjectInspector(elOi);
    }

    /**
     * Returns the cached ascending comparator for {@code field}, creating it on first use.
     *
     * @param field name of the struct field to sort by
     * @return comparator ordering struct elements ascending by that field
     */
    Comparator<Object> getComparator(Text field) {
        return getComparator(field.toString());
    }

    /**
     * String overload of {@link #getComparator(Text)}.
     *
     * @param fieldName name of the struct field to sort by
     * @return comparator ordering struct elements ascending by that field
     */
    Comparator<Object> getComparator(String fieldName) {
        Comparator<Object> comparator = comparatorCache.get(fieldName);
        if (comparator == null) {
            comparator = new StructFieldComparator(fieldName);
            comparatorCache.put(fieldName, comparator);
        }
        return comparator;
    }

    /**
     * Sorts a shallow copy of the input array by the requested field and direction.
     *
     * @param dos deferred arguments: the array, the sort-field name and optionally the
     *            ascending flag
     * @return a new sorted list, or {@code null} if the array is NULL
     * @throws HiveException if the wrong number of arguments is supplied or the field name is NULL
     */
    @Override
    public Object evaluate(DeferredObject[] dos) throws HiveException {
        if (dos == null || dos.length != argumentOIs.length) {
            throw new HiveException("received " + (dos == null ? "null" : Integer.toString(dos.length))
                    + " elements instead of " + argumentOIs.length);
        }

        // NULL array in, NULL out
        Object array = dos[0].get();
        List<?> elements = array == null ? null : loi.getList(array);
        if (elements == null) {
            return null;
        }

        Object fieldName = fieldNameOi.getPrimitiveJavaObject(dos[1].get());
        if (fieldName == null) {
            throw new HiveException("sort field name must not be null");
        }

        Comparator<Object> comparator = getComparator(fieldName.toString());
        if (!isAscending(dos)) {
            comparator = Collections.reverseOrder(comparator);
        }

        // Sort a shallow copy: the input list may be referenced elsewhere in the query.
        List<Object> sorted = new ArrayList<Object>(elements);
        Collections.sort(sorted, comparator);
        return sorted;
    }

    /** Reads the optional ascending flag; an omitted or NULL flag means ascending. */
    private boolean isAscending(DeferredObject[] dos) throws HiveException {
        if (ascOi == null) {
            return true;
        }
        Object flag = ascOi.getPrimitiveJavaObject(dos[2].get());
        return flag == null || ((Boolean) flag).booleanValue();
    }

    @Override
    public String getDisplayString(String[] children) {
        if (children == null) {
            return null;
        }
        // include every argument, including the optional sort direction
        StringBuilder sb = new StringBuilder(this.getClass().getCanonicalName()).append('(');
        for (int i = 0; i < children.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(children[i]);
        }
        return sb.append(')').toString();
    }

    /**
     * Orders struct elements (as read by {@code elOi}) ascending by the value of one struct
     * field, comparing field values with Hive's {@link ObjectInspectorUtils#compare}.
     */
    public class StructFieldComparator implements Comparator<Object> {
        StructField field;

        /**
         * @param fieldName name of the struct field to compare; resolved once against elOi
         */
        public StructFieldComparator(String fieldName) {
            field = elOi.getStructFieldRef(fieldName);
        }

        @Override
        public int compare(Object o1, Object o2) {
            // No explicit NULL handling here: NULL structs / field values are passed on to
            // getStructFieldData and ObjectInspectorUtils.compare as-is.
            ObjectInspector fieldOi = field.getFieldObjectInspector();
            Object f1 = elOi.getStructFieldData(o1, field);
            Object f2 = elOi.getStructFieldData(o2, field);
            return ObjectInspectorUtils.compare(f1, fieldOi, f2, fieldOi);
        }
    }

}

Test code:

-- Same aggregation as the earlier collect_map example, but the third argument (false) makes
-- array_struct_sort order each class's structs by 'age' in descending order.
SELECT class, collect_map(class, struct_array) as res
FROM (
    SELECT class, array_struct_sort(collect_list(struct_t), 'age', false) as struct_array
    FROM (
        SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
        union all 
        SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
        union all 
        SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
        union all
        SELECT '2' as class, named_struct('name', 'N000', 'age', '17') as struct_t
    ) as test_data
    group by class
) as tmp
group by class
;

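Test screenshot (not included in this copy). With false as the third argument, each array is sorted by age in descending order, e.g. class 1 yields [{"name":"N003","age":"20"},{"name":"N002","age":"19"}].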

5. References

Roberto Congiu's blog: "Structured data in Hive: a generic UDF to sort arrays of structs"

Analysis of Collections.sort

 

Guess you like

Source: blog.csdn.net/qq_37771475/article/details/126371687