Hive UDF: matchNWords

Recently, a business requirement called for a UDF that counts how many words two strings have in common. A word is defined as follows: in Chinese, each single character is a word; in English, a run of consecutive letters delimited by spaces is a word.

1. The first step: word segmentation

    /**
     * Split a String into words. An English word counts as one word; each Chinese character counts as one word.
     *
     * @param src the string to split
     * @return ArrayList<String>
     */
    public static ArrayList<String> splitWords(String src) {
        // Step 1: split the String into chars. Each Chinese character in the BMP is a single (two-byte) char.
        char[] charArray = src.toCharArray();

        ArrayList<String> result = new ArrayList<>();

        // Step 2: group the chars into words.
        for (int i = 0, j = 0, e = 0; i < charArray.length; ) {
            // Space (half-width or full-width)
            if (charArray[i] == '\u0020' || charArray[i] == '\u3000') {
                i++;
                j++;
                e = i;
            }
            // English character (any other char in U+0000..U+00FF)
            else if (charArray[i] >= 0x0000 && charArray[i] <= 0x00FF) {
                j++;
                // End of input: emit the pending English word (avoids an out-of-bounds read)
                if (j >= charArray.length) {
                    StringBuilder sb = new StringBuilder();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    i++;
                    e = i;
                }
                // The next char is also English (and not a half-width space)
                else if (charArray[j] != '\u0020' && charArray[j] >= 0x0000 && charArray[j] <= 0x00FF) {
                    i++;
                }
                // Otherwise the next char is Chinese or a space: the English word is complete
                else {
                    StringBuilder sb = new StringBuilder();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    e = i;
                    i++;
                }
            }
            // Otherwise: a Chinese character
            else {
                StringBuilder sb = new StringBuilder();
                sb.append(charArray[i]);
                result.add(sb.toString());
                i++;
                j++;
                e = i;
            }
        }
        return result;
    }

The idea is as follows:

  1. Variable i records the current scan position, variable j records the position after i (a one-character lookahead), and variable e records the start position of the current English word.
  2. If the character at i is English, do j++ and inspect the character at j. If it is also English, do i++ (so i tracks the end of the English word). Otherwise the English word has ended: take the characters from e up to, but not including, j as one word, then set e=i and do i++. The word is emitted in the same way when j runs past the end of the string.
  3. If the character at i is Chinese, store it directly as a word, then i++, j++, e=i.
  4. If the character at i is a space, store nothing, then i++, j++, e=i.

Note that "English" here means any non-space character in the range U+0000 to U+00FF, so digits and punctuation stay attached to the word (for example, "Co.," is one word).
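
The same rules can also be expressed with a single buffer instead of three indices. The following is a minimal sketch, not the UDF's actual code: the class name SplitWordsSketch and its helper methods are hypothetical, and it assumes, like splitWords, that any non-space character in the range U+0000 to U+00FF belongs to an English word.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical class, used only to illustrate the segmentation rules.
class SplitWordsSketch {

    // Half-width or full-width space, as in splitWords.
    static boolean isSpace(char c) {
        return c == '\u0020' || c == '\u3000';
    }

    // Assumption carried over from splitWords: U+0000..U+00FF counts as "English".
    static boolean isLatin1(char c) {
        return c <= 0x00FF;
    }

    static List<String> split(String src) {
        List<String> words = new ArrayList<>();
        StringBuilder current = new StringBuilder(); // pending English word
        for (char c : src.toCharArray()) {
            if (!isSpace(c) && isLatin1(c)) {
                current.append(c); // extend the English word
                continue;
            }
            // A space or a non-Latin-1 char ends any pending English word.
            if (current.length() > 0) {
                words.add(current.toString());
                current.setLength(0);
            }
            if (!isSpace(c)) {
                words.add(String.valueOf(c)); // one word per Chinese character
            }
        }
        if (current.length() > 0) {
            words.add(current.toString()); // flush a trailing English word
        }
        return words;
    }

    public static void main(String[] args) {
        System.out.println(split("China 中国电力公司CN"));
    }
}
```

For the input "China 中国电力公司CN" this sketch yields 8 words: China, the six Chinese characters, and CN.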

The test is as follows:

    @Test
    public void testSplitWords() {
        ArrayList<String> result = UDFMatchNWords.splitWords("China 中国电力公司CN");
        for (String s : result) {
            System.out.println(s);
        }
        Assert.assertEquals(8, result.size());

        System.out.println("---------------------------");

        ArrayList<String> result2 = UDFMatchNWords.splitWords("中国\u3000China电力公司 之 上海分公司 BB");
        for (String s : result2) {
            System.out.println(s);
        }
        Assert.assertEquals(14, result2.size());
    }

2. The second step: ArrayList to HashMap

For a repeated word, the number of matches is the smaller of its counts in the two strings, so the word counts of each string are needed first. Collectors.toMap does this directly: the word is the key, each occurrence maps to the value 1, and when keys collide the values are added together.

Reference: "Collectors.toMap usage tips (List to Map is super convenient)"

    /**
     * Convert ArrayList To HashMap. The Key is word, and the value is the count of word.
     *
     * @param list
     * @return HashMap<String, Integer>
     */
    public static HashMap<String, Integer> convertListToHashMap(List<String> list) {
        HashMap<String, Integer> map = list.stream().collect(Collectors.toMap(
                String::toString, v -> 1, (v1, v2) -> (v1 + v2), HashMap::new
        ));
        return map;
    }

The test is as follows:

    @Test
    public void testConvertListToHashMap() {
        List<String> list = Arrays.asList("阿", "巴", "里", "巴", "巴");
        Map<String, Integer> map = UDFMatchNWords.convertListToHashMap(list);
        map.forEach((k, v) -> System.out.println("word:" + k + ", count:" + v));
        // Integer.valueOf replaces the deprecated Integer constructor
        Assert.assertEquals(Integer.valueOf(3), map.get("巴"));
        Assert.assertEquals(Integer.valueOf(1), map.get("阿"));
    }
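
For comparison, the same counting can be written without Collectors.toMap. The sketch below is illustrative only, and the class and method names are hypothetical. Map.merge adds 1 on each key conflict, which is what the merge function (v1, v2) -> (v1 + v2) does in convertListToHashMap; Collectors.groupingBy with Collectors.counting is a stream alternative that produces Long counts.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical class, used only to compare ways of counting words.
class WordCountSketch {

    // Loop variant: insert 1 for a new key, otherwise add 1 to the existing count.
    static Map<String, Integer> countWithMerge(List<String> words) {
        Map<String, Integer> counts = new HashMap<>();
        for (String w : words) {
            counts.merge(w, 1, Integer::sum);
        }
        return counts;
    }

    // Stream variant: note that counting() yields Long values, not Integer.
    static Map<String, Long> countWithGrouping(List<String> words) {
        return words.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("阿", "巴", "里", "巴", "巴");
        System.out.println(countWithMerge(words).get("巴"));    // prints 3
        System.out.println(countWithGrouping(words).get("巴")); // prints 3
    }
}
```

The toMap version keeps everything in one expression; the Map.merge loop avoids the stream overhead, which may matter for a UDF that is called once per row.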

3. Matching

With splitWords and convertListToHashMap in place, evaluate only needs to call these two methods and then do the matching. The complete code is as follows.

package com.scb.dss.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

@Description(name = "matchNWords",
        value = "_FUNC_(str1, str2, n) - Return TRUE if there are n words matched between str1 and str2")
public class UDFMatchNWords extends UDF {

    /**
     * Return TRUE if at least n words match between str1 and str2
     * @param str1
     * @param str2
     * @param n
     * @return TRUE if the match count reaches n; NULL if any argument is NULL
     */
    public Boolean evaluate(String str1, String str2, Integer n) {
        // NULL in, NULL out: avoids a NullPointerException on NULL arguments
        if (str1 == null || str2 == null || n == null) {
            return null;
        }

        // Step1: split words
        ArrayList<String> s1 = splitWords(str1);
        ArrayList<String> s2 = splitWords(str2);

        // Step2: match n words
        // If str1 or str2 has fewer than n words, return false immediately
        if (s1.size() < n || s2.size() < n) {
            return false;
        }

        // convert to HashMap
        HashMap<String, Integer> map1 = convertListToHashMap(s1);
        HashMap<String, Integer> map2 = convertListToHashMap(s2);

        int matchCnt = 0;

        for (String s : map1.keySet()) {
            if (map2.containsKey(s)) {
                matchCnt += Math.min(map1.get(s), map2.get(s));
                // Short-circuit once n is reached. Use >= rather than ==, because
                // matchCnt can grow by more than 1 per word and step over n.
                if (matchCnt >= n) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * return the count of match words between str1 and str2
     * @param str1
     * @param str2
     * @return match words count
     */
    public Integer evaluate(String str1, String str2) {
        if (str1 == null || str2 == null) {
            return null;
        }

        // Step1: split words
        ArrayList<String> s1 = splitWords(str1);
        ArrayList<String> s2 = splitWords(str2);

        // Step2: match n words
        // convert to HashMap
        HashMap<String, Integer> map1 = convertListToHashMap(s1);
        HashMap<String, Integer> map2 = convertListToHashMap(s2);

        int matchCnt = 0;

        for (String s : map1.keySet()) {
            if (map2.containsKey(s)) {
                matchCnt += Math.min(map1.get(s), map2.get(s));
            }
        }
        return matchCnt;
    }

    /**
     * Convert ArrayList To HashMap. The Key is word, and the value is the count of word.
     *
     * @param list
     * @return HashMap<String, Integer>
     */
    public static HashMap<String, Integer> convertListToHashMap(List<String> list) {
        HashMap<String, Integer> map = list.stream().collect(Collectors.toMap(
                String::toString, v -> 1, (v1, v2) -> (v1 + v2), HashMap::new
        ));
        return map;
    }

    /**
     * Split a String into words. An English word counts as one word; each Chinese character counts as one word.
     *
     * @param src the string to split
     * @return ArrayList<String>
     */
    public static ArrayList<String> splitWords(String src) {
        // Step 1: split the String into chars. Each Chinese character in the BMP is a single (two-byte) char.
        char[] charArray = src.toCharArray();

        ArrayList<String> result = new ArrayList<>();

        // Step 2: group the chars into words.
        for (int i = 0, j = 0, e = 0; i < charArray.length; ) {
            // Space (half-width or full-width)
            if (charArray[i] == '\u0020' || charArray[i] == '\u3000') {
                i++;
                j++;
                e = i;
            }
            // English character (any other char in U+0000..U+00FF)
            else if (charArray[i] >= 0x0000 && charArray[i] <= 0x00FF) {
                j++;
                // End of input: emit the pending English word (avoids an out-of-bounds read)
                if (j >= charArray.length) {
                    StringBuilder sb = new StringBuilder();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    i++;
                    e = i;
                }
                // The next char is also English (and not a half-width space)
                else if (charArray[j] != '\u0020' && charArray[j] >= 0x0000 && charArray[j] <= 0x00FF) {
                    i++;
                }
                // Otherwise the next char is Chinese or a space: the English word is complete
                else {
                    StringBuilder sb = new StringBuilder();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    e = i;
                    i++;
                }
            }
            // Otherwise: a Chinese character
            else {
                StringBuilder sb = new StringBuilder();
                sb.append(charArray[i]);
                result.add(sb.toString());
                i++;
                j++;
                e = i;
            }
        }
        return result;
    }
}

The test class:

package com.scb.dss.udf;

import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class UDFMatchNWordsTest {

    private final UDFMatchNWords matchNWords = new UDFMatchNWords();

    @Test
    public void testSplitWords() {
        ArrayList<String> result = UDFMatchNWords.splitWords("China 中国电力公司CN");
        for (String s : result) {
            System.out.println(s);
        }
        Assert.assertEquals(8, result.size());

        System.out.println("---------------------------");

        ArrayList<String> result2 = UDFMatchNWords.splitWords("中国\u3000China电力公司 之 上海分公司 BB");
        for (String s : result2) {
            System.out.println(s);
        }
        Assert.assertEquals(14, result2.size());
    }

    @Test
    public void testConvertListToHashMap() {
        List<String> list = Arrays.asList("阿", "巴", "里", "巴", "巴");
        Map<String, Integer> map = UDFMatchNWords.convertListToHashMap(list);
        map.forEach((k, v) -> System.out.println("word:" + k + ", count:" + v));
        // Integer.valueOf replaces the deprecated Integer constructor
        Assert.assertEquals(Integer.valueOf(3), map.get("巴"));
        Assert.assertEquals(Integer.valueOf(1), map.get("阿"));
    }

    @Test
    public void testEvaluate() {
        // When a word matches, the smaller of its two counts is the number of matches
        Assert.assertEquals(true, matchNWords.evaluate("阿里巴巴", "巴土巴士", 2));
        Assert.assertEquals(true, matchNWords.evaluate("阿里巴巴", "巴士公司", 1));

        // If str1 or str2 has fewer than n words, false is returned immediately
        Assert.assertEquals(false, matchNWords.evaluate("阿里巴巴", "阿里巴巴有限公司", 5));

        // Complex match: mixed Chinese and English
        Assert.assertEquals(true, matchNWords.evaluate("China 国网天津电力公司 Co., Ltd.", "国网电力公司 China 天津分公司 Co., Ltd.", 11));
    }

    @Test
    public void testEvaluate2() {
        Assert.assertEquals(Integer.valueOf(2), matchNWords.evaluate("hello hi world", "hi world"));
    }
}
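
To make the matching rule concrete, here is a small sketch of the min-count logic on its own. The class MatchCountSketch and its helpers are hypothetical; charCounts counts single characters only, a simplification that is sufficient for all-Chinese input. It also illustrates why, assuming "n words matched" means "at least n", a threshold test on the running total is safer with >= than with ==: the total can grow by more than 1 per word.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical class, used only to illustrate the min-count matching rule.
class MatchCountSketch {

    // Sum, over the words present in both maps, of the smaller of the two counts.
    static int matchCount(Map<String, Integer> a, Map<String, Integer> b) {
        int total = 0;
        for (Map.Entry<String, Integer> entry : a.entrySet()) {
            Integer other = b.get(entry.getKey());
            if (other != null) {
                total += Math.min(entry.getValue(), other);
            }
        }
        return total;
    }

    // Simplification: every char is its own word (fine for all-Chinese strings).
    static Map<String, Integer> charCounts(String s) {
        Map<String, Integer> counts = new HashMap<>();
        for (char c : s.toCharArray()) {
            counts.merge(String.valueOf(c), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // "巴" occurs twice on both sides, so it contributes min(2, 2) = 2.
        System.out.println(matchCount(charCounts("阿里巴巴"), charCounts("巴土巴士"))); // prints 2

        // "巴" occurs three times on both sides: the total goes from 0 straight to 3.
        int n = 2;
        int cnt = matchCount(charCounts("巴巴巴"), charCounts("巴巴巴"));
        System.out.println(cnt == n); // prints false
        System.out.println(cnt >= n); // prints true
    }
}
```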

4. Publish UDF

To publish the UDF to Hive, refer to the previous section, "Hive UDF <User-Defined Function> Getting Started". Once the function is registered, both overloads can be called:

select matchNWords('Hello Hi World', 'Hello World');

select matchNWords('Hello Hi World', 'Hello World', 2);

The evaluate method of a Hive UDF supports overloading. For how an overload is selected and executed, see "Analysis of the use of the evaluate method of Hive UDF".
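
As a rough illustration of overloading, the sketch below looks up an evaluate method by its argument types using plain Java reflection. This is not Hive's actual resolution logic, which is more involved; OverloadSketch, DemoUdf and resolve are hypothetical names, and the method bodies are stand-ins.

```java
import java.lang.reflect.Method;

// Hypothetical class, used only to illustrate overload selection by argument types.
class OverloadSketch {

    // Two overloads with the same name but different parameter lists (stand-in bodies).
    public static class DemoUdf {
        public Integer evaluate(String a, String b) {
            return 0;
        }

        public Integer evaluate(String a, String b, Integer n) {
            return n;
        }
    }

    // Find the public evaluate method whose parameter types match exactly.
    static Method resolve(Class<?> udfClass, Class<?>... argTypes) {
        try {
            return udfClass.getMethod("evaluate", argTypes);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("no matching evaluate overload", ex);
        }
    }

    public static void main(String[] args) {
        Method two = resolve(DemoUdf.class, String.class, String.class);
        Method three = resolve(DemoUdf.class, String.class, String.class, Integer.class);
        System.out.println(two.getParameterCount());   // prints 2
        System.out.println(three.getParameterCount()); // prints 3
    }
}
```

The point is only that the argument list of the call decides which overload runs, which is why the two queries above can share one function name.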

Origin blog.csdn.net/qq_37771475/article/details/121695144