Implementing a HiveSQL Parser in Java

Copyright notice: for learning and discussion only; reposting without the author's permission is prohibited, as is commercial use. https://blog.csdn.net/u012965373/article/details/83750759

import com.xxxx.model.SQLParserResult;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.parse.*;
import java.util.*;

/**
 * @author yangxin-ryan
 * HiveParse utility class.
 * Parses a SQL statement and extracts the database name, table name, and
 * concrete operation for every table the statement touches.
 * Scope: SELECT statements are resolved down to table and column level;
 * all other operations are resolved to table level only.
 * Approach: depth-first traversal of the AST. An operation token determines
 * the current operation; TOK_TAB or TOK_TABREF identifies the table being
 * operated on. When a sub-clause is encountered, the current state is pushed
 * onto a stack while the sub-clause is processed, and popped afterwards.
 */
public class HiveParseUtil {
    private static final Log LOG = LogFactory.getLog(HiveParseUtil.class);
    private static final String UNKNOWN = "UNKNOWN";
    private Map<String, String> tableAlias = new HashMap<>();
    private Map<String, String> cols = new TreeMap<>();
    private Map<String, String> colAlias = new TreeMap<>();
    private Set<String> tableOpers = new HashSet<>();
    private Set<String> tables = new HashSet<>();
    private Stack<String> tableNameStack = new Stack<>();
    private Stack<Oper> operStack = new Stack<>();
    private String nowQueryTable = ""; // Naming and handling are imprecise: the current query may touch more than one table, so a set of tables per query/FROM node would model this better.
    private Oper oper;
    private boolean joinClause = false;
    public static final String OPER_SEPARATOR = "\t";

    public void parse(String sql) throws ParseException {
        ParseDriver pd = new ParseDriver();
        ASTNode ast = pd.parse(sql);
        parseIteral(ast);
    }
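
    /*
     * Simplified sketch of the AST that ParseDriver produces for
     * "select * from a.zpc1" (real trees carry additional bookkeeping nodes):
     *
     *   TOK_QUERY
     *   +-- TOK_FROM
     *   |     +-- TOK_TABREF
     *   |           +-- TOK_TABNAME (a, zpc1)
     *   +-- TOK_INSERT
     *         +-- TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)
     *         +-- TOK_SELECT
     *               +-- TOK_SELEXPR (TOK_ALLCOLREF)
     */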

    private enum Oper {
        SELECT, INSERT, DROP, TRUNCATE, LOAD, CREATETABLE, ALTER, CREATE
    }

    public Map<String, String> getTableAlias() {
        return tableAlias;
    }

    public Map<String, String> getCols() {
        return cols;
    }

    public Map<String, String> getColAlias() {
        return colAlias;
    }

    public Set<String> getTables() {
        return tables;
    }

    public Set<String> getTableOpers() {
        return tableOpers;
    }

    public Stack<String> getTableNameStack() {
        return tableNameStack;
    }

    public Stack<Oper> getOperStack() {
        return operStack;
    }

    public String getNowQueryTable() {
        return nowQueryTable;
    }

    public Oper getOper() {
        return oper;
    }

    /**
     * Depth-first traversal of one AST node: prepares state, recurses into
     * the children, parses the current node, then restores state.
     * @param ast the AST node to traverse
     * @return the set of table names found in this subtree
     */
    public Set<String> parseIteral(ASTNode ast) {
        Set<String> set = new HashSet<>();
        prepareToParseCurrentNodeAndChilds(ast);
        set.addAll(parseChildNodes(ast));
        set.addAll(parseCurrentNode(ast, set));
        endParseCurrentNode(ast);
        return set;
    }

    /**
     * Parses a SQL statement and returns the database name, table name,
     * and operation for each table the statement touches.
     * @param querySQL the SQL statement to parse
     * @return one SQLParserResult per table/operation pair
     * @throws ParseException if Hive cannot parse the statement
     */
    public static List<SQLParserResult> sqlParser(String querySQL) throws ParseException {
        LOG.info("Parser Engine HiveParser ...");
        List<SQLParserResult> parserResultList = new ArrayList<>();
        HiveParseUtil hp = new HiveParseUtil();
        hp.parse(querySQL);
        for (String tableOper : hp.getTableOpers()) {
            SQLParserResult parserResult = new SQLParserResult();
            String[] tableAndOper = tableOper.split(OPER_SEPARATOR);
            String[] qualifiedName = tableAndOper[0].split("\\.");
            // A table may appear without a database prefix; fall back to
            // Hive's "default" database instead of assuming a dot is present
            // (the unguarded split would throw ArrayIndexOutOfBoundsException).
            String dbName = qualifiedName.length > 1 ? qualifiedName[0] : "default";
            String tableName = qualifiedName.length > 1 ? qualifiedName[1] : qualifiedName[0];
            String operation = tableAndOper[1];
            LOG.info("dbName = " + dbName);
            LOG.info("tableName = " + tableName);
            LOG.info("operation = " + operation);
            parserResult.setDbName(dbName);
            parserResult.setTableName(tableName);
            parserResult.setOperation(operation);
            parserResultList.add(parserResult);
        }
        return parserResultList;
    }
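
    /*
     * Example: for "Select * from a.zpc1" (see main below), tableOpers holds
     * a single entry "a.zpc1\tSELECT", which sqlParser splits into
     * dbName = "a", tableName = "zpc1", operation = "SELECT".
     */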

    /**
     * Called before a node and its children are parsed: pushes the current
     * state and switches the operation for tokens that open a new scope.
     * @param ast the node about to be parsed
     */
    private void prepareToParseCurrentNodeAndChilds(ASTNode ast) {
        if (ast.getToken() != null) {
            switch (ast.getToken().getType()) {
                case HiveParser.TOK_RIGHTOUTERJOIN:
                case HiveParser.TOK_LEFTOUTERJOIN:
                case HiveParser.TOK_JOIN:
                    joinClause = true;
                    break;
                case HiveParser.TOK_QUERY:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    nowQueryTable = "";
                    oper = Oper.SELECT;
                    break;
                case HiveParser.TOK_INSERT:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.INSERT;
                    break;
                case HiveParser.TOK_SELECT:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.SELECT;
                    break;
                case HiveParser.TOK_DROPTABLE:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.DROP;
                    break;
                case HiveParser.TOK_TRUNCATETABLE:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.TRUNCATE;
                    break;
                case HiveParser.TOK_LOAD:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.LOAD;
                    break;
                case HiveParser.TOK_CREATETABLE:
                    tableNameStack.push(nowQueryTable);
                    operStack.push(oper);
                    oper = Oper.CREATETABLE;
                    break;
            }
            // All ALTER tokens fall in a contiguous range of type constants;
            // the outer null check already guarantees getToken() is non-null.
            if (ast.getToken().getType() >= HiveParser.TOK_ALTERDATABASE_PROPERTIES
                    && ast.getToken().getType() <= HiveParser.TOK_ALTERVIEW_RENAME) {
                oper = Oper.ALTER;
            }
        }
    }

    // Called when leaving a node: restores the state saved when it was entered.
    private void endParseCurrentNode(ASTNode ast) {
        if (ast.getToken() != null) {
            switch (ast.getToken().getType()) {
                // End of a JOIN clause: leave the join context.
                case HiveParser.TOK_RIGHTOUTERJOIN:
                case HiveParser.TOK_LEFTOUTERJOIN:
                case HiveParser.TOK_JOIN:
                    joinClause = false;
                    break;
                case HiveParser.TOK_QUERY:
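                    // Only TOK_INSERT and TOK_SELECT restore the saved state;
                    // what was pushed for TOK_QUERY stays on the stack.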
                    break;
                case HiveParser.TOK_INSERT:
                case HiveParser.TOK_SELECT:
                    nowQueryTable = tableNameStack.pop();
                    oper = operStack.pop();
                    break;
            }
        }
    }

    // Parse the current node itself (its children have already been visited).
    private Set<String> parseCurrentNode(ASTNode ast, Set<String> set) {
        if (ast.getToken() != null) {
            switch (ast.getToken().getType()) {
                case HiveParser.TOK_TABLE_PARTITION:
                    if (ast.getChildCount() != 2) {
                        String table = BaseSemanticAnalyzer
                                .getUnescapedName((ASTNode) ast.getChild(0));
                        if (oper == Oper.SELECT) {
                            nowQueryTable = table;
                        }
                        tables.add(table);
                        tableOpers.add(table + OPER_SEPARATOR + oper);
                    }
                    break;
                case HiveParser.TOK_TAB:
                    String tableTab = BaseSemanticAnalyzer
                            .getUnescapedName((ASTNode) ast.getChild(0));
                    if (oper == Oper.SELECT) {
                        nowQueryTable = tableTab;
                    }
                    tables.add(tableTab);
                    tableOpers.add(tableTab + OPER_SEPARATOR + oper);
                    break;
                case HiveParser.TOK_TABREF:
                    ASTNode tabTree = (ASTNode) ast.getChild(0);
                    String tableName = (tabTree.getChildCount() == 1) ? BaseSemanticAnalyzer
                            .getUnescapedName((ASTNode) tabTree.getChild(0))
                            : BaseSemanticAnalyzer
                            .getUnescapedName((ASTNode) tabTree.getChild(0))
                            + "." + tabTree.getChild(1);
                    if (oper == Oper.SELECT) {
                        if (joinClause && !"".equals(nowQueryTable)) {
                            nowQueryTable += "&" + tableName;
                        } else {
                            nowQueryTable = tableName;
                        }
                        set.add(tableName);
                    }
                    tables.add(tableName);
                    tableOpers.add(tableName + OPER_SEPARATOR + oper);
                    if (ast.getChild(1) != null) {
                        String alia = ast.getChild(1).getText().toLowerCase();
                        tableAlias.put(alia, tableName);
                    }
                    break;
                case HiveParser.TOK_TABLE_OR_COL:
                    if (ast.getParent().getType() != HiveParser.DOT) {
                        String col = ast.getChild(0).getText().toLowerCase();
                        if (tableAlias.get(col) == null
                                && colAlias.get(nowQueryTable + "." + col) == null) {
                            if (nowQueryTable.indexOf("&") > 0) {
                                cols.put(UNKNOWN + "." + col, "");
                            } else {
                                cols.put(nowQueryTable + "." + col, "");
                            }
                        }
                    }
                    break;
                case HiveParser.TOK_ALLCOLREF:
                    cols.put(nowQueryTable + ".*", "");
                    break;
                case HiveParser.TOK_SUBQUERY:
                    if (ast.getChildCount() == 2) {
                        // Map the subquery alias to the concatenation of the
                        // tables found inside it, e.g. "t" -> "db.a&db.b".
                        String subQueryAlias = unescapeIdentifier(ast.getChild(1)
                                .getText());
                        String aliaReal = "";
                        for (String table : set) {
                            aliaReal += table + "&";
                        }
                        if (aliaReal.length() != 0) {
                            aliaReal = aliaReal.substring(0, aliaReal.length() - 1);
                        }
                        tableAlias.put(subQueryAlias, aliaReal);
                    }
                    break;
                case HiveParser.TOK_SELEXPR:
                    if (ast.getChild(0).getType() == HiveParser.TOK_TABLE_OR_COL) {
                        String column = ast.getChild(0).getChild(0).getText()
                                .toLowerCase();
                        if (nowQueryTable.indexOf("&") > 0) {
                            cols.put(UNKNOWN + "." + column, "");
                        } else if (colAlias.get(nowQueryTable + "." + column) == null) {
                            cols.put(nowQueryTable + "." + column, "");
                        }
                    } else if (ast.getChild(1) != null) {
                        String columnAlia = ast.getChild(1).getText().toLowerCase();
                        colAlias.put(nowQueryTable + "." + columnAlia, "");
                    }
                    break;
                case HiveParser.DOT:
                    // The case label already guarantees a DOT node; resolve
                    // "alias.column" references against the known aliases.
                    if (ast.getChildCount() == 2
                            && ast.getChild(0).getType() == HiveParser.TOK_TABLE_OR_COL
                            && ast.getChild(0).getChildCount() == 1
                            && ast.getChild(1).getType() == HiveParser.Identifier) {
                        String alia = BaseSemanticAnalyzer
                                .unescapeIdentifier(ast.getChild(0)
                                        .getChild(0).getText()
                                        .toLowerCase());
                        String column = BaseSemanticAnalyzer
                                .unescapeIdentifier(ast.getChild(1)
                                        .getText().toLowerCase());
                        String realTable = null;
                        // e.g. tableOpers = [b SELECT, a SELECT]
                        if (!tableOpers.contains(alia + OPER_SEPARATOR + oper)
                                && tableAlias.get(alia) == null) {
                            tableAlias.put(alia, nowQueryTable);
                        }
                        if (tableOpers.contains(alia + OPER_SEPARATOR + oper)) {
                            realTable = alia;
                        } else if (tableAlias.get(alia) != null) {
                            realTable = tableAlias.get(alia);
                        }
                        // Unresolvable or ambiguous (joined) tables fall back to UNKNOWN.
                        if (realTable == null || realTable.length() == 0 || realTable.indexOf("&") > 0) {
                            realTable = UNKNOWN;
                        }
                        cols.put(realTable + "." + column, "");
                    }
                    break;
                case HiveParser.TOK_ALTERTABLE_ADDPARTS:
                case HiveParser.TOK_ALTERTABLE_RENAME:
                case HiveParser.TOK_ALTERTABLE_ADDCOLS:
                    ASTNode alterTableName = (ASTNode) ast.getChild(0);
                    tables.add(alterTableName.getText());
                    tableOpers.add(alterTableName.getText() + OPER_SEPARATOR + oper);
                    break;
            }
        }
        return set;
    }
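
    /*
     * Example for the DOT branch above: in "select u.name from db.users u",
     * TOK_TABREF first records tableAlias = {u -> db.users}; the DOT node
     * (TOK_TABLE_OR_COL u, Identifier name) then resolves the alias, so cols
     * receives the key "db.users.name".
     */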

    // Recursively parse all child nodes and collect the tables they reference.
    private Set<String> parseChildNodes(ASTNode ast) {
        Set<String> set = new HashSet<>();
        int numCh = ast.getChildCount();
        if (numCh > 0) {
            for (int num = 0; num < numCh; num++) {
                ASTNode child = (ASTNode) ast.getChild(num);
                set.addAll(parseIteral(child));
            }
        }
        return set;
    }

    public static String unescapeIdentifier(String val) {
        if (val == null) {
            return null;
        }
        // The length check guards against empty or single-character input,
        // which would otherwise throw StringIndexOutOfBoundsException.
        if (val.length() > 1 && val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') {
            val = val.substring(1, val.length() - 1);
        }
        return val;
    }
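
    // Example: unescapeIdentifier("`zpc1`") returns "zpc1"; input without
    // surrounding backticks is returned unchanged.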

    public static void main(String[] args) throws ParseException {
        String sql = "Select * from a.zpc1";
        sqlParser(sql);
    }
}
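
For a quick end-to-end check, the utility can be driven as below. This is a minimal sketch: it assumes the demo class sits next to HiveParseUtil, and that SQLParserResult exposes getters matching the setters used in sqlParser (getDbName(), getTableName(), getOperation()); the original post does not show that class.

import com.xxxx.model.SQLParserResult;
import org.apache.hadoop.hive.ql.parse.ParseException;
import java.util.List;

public class HiveParseDemo {
    public static void main(String[] args) throws ParseException {
        // An INSERT ... SELECT with a JOIN exercises both the INSERT (TOK_TAB)
        // and the SELECT (TOK_TABREF) paths of the parser.
        String sql = "INSERT OVERWRITE TABLE dw.user_stat "
                + "SELECT a.uid, b.city FROM ods.users a "
                + "JOIN ods.addresses b ON a.uid = b.uid";
        List<SQLParserResult> results = HiveParseUtil.sqlParser(sql);
        for (SQLParserResult r : results) {
            System.out.println(r.getDbName() + "." + r.getTableName()
                    + " -> " + r.getOperation());
        }
        // Expected output (order not guaranteed, since tableOpers is a HashSet):
        //   dw.user_stat -> INSERT
        //   ods.users -> SELECT
        //   ods.addresses -> SELECT
    }
}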
