sql parser

keywordDict.py:

# -*- coding:utf-8 -*-
# CREATED BY: bohuai jiang 
# CREATED ON: 2019/9/18
# LAST MODIFIED ON:
# AIM: sort key words

from service.sql_parser_graph.units import ParseUnit

# Weights used by Keywordstack to decide which of several consecutive keyword
# units is kept: the highest weight wins, and any keyword not listed here
# weighs -1.  ('WEIGTH' is a misspelling, kept because other modules may
# import this name.)
KEY_WEIGTH = {'SELECT': 10,
              'INSERT': 10,
              'WHERE': 8,
              'FROM': 9,
              'VALUES': 8,
              'AND': 5,
              ',': 0}


class Keywordstack:
    """Remember the highest-weighted keyword unit offered since the last reset.

    Despite its name this is not a LIFO stack.  ``insert`` counts every unit
    it is given and keeps only the one with the largest ``KEY_WEIGTH`` weight
    (unknown names weigh -1; on a tie the earlier unit is kept).  ``pop``
    returns that unit without removing it.
    """

    def __init__(self):
        self.reset()

    def insert(self, value: ParseUnit) -> None:
        """Offer ``value``; it becomes the current best only if strictly heavier."""
        self.length += 1
        candidate_weight = KEY_WEIGTH.get(value.name.upper(), -1)
        if candidate_weight <= self.weight:
            return
        self.value, self.weight = value, candidate_weight

    def pop(self):
        """Return the current best unit (``None`` if nothing was inserted)."""
        return self.value

    def reset(self):
        """Forget every inserted unit."""
        self.value = None
        self.weight = -float('inf')
        self.length = 0

    def is_empty(self):
        """True when nothing has been inserted since the last reset."""
        return self.length <= 0

SQLParser.py:

# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
# -*- coding: utf-8 -*-
import sqlparse
from sqlparse.sql import Where, IdentifierList, Identifier, TokenList, Token, Parenthesis, Comment, Case, Operation, \
    Function, Values
from service.sql_parser_graph.KeywrodDict import Keywordstack
from service.sql_parser_graph.units import ParseUnitList
from typing import Union, List, Optional,Tuple
import re

# Keyword groups used by SQLParser.lu_parse, matched against a keyword token's
# ``normalized`` text.
# Tokens after these keywords become table units ('TAB'); any keyword that
# contains 'JOIN' is handled the same way.
TAB_KEYWORD = ['FROM', 'LEFT JOIN', 'UPDATE', 'EXISTS', 'INNER JOIN', 'OUTER JOIN', 'JOIN', 'RIGHT JOIN', 'INTO']
# Tokens after these keywords become column units ('COL').
COL_KEYWORD = ['INSERT', 'SELECT', 'WHERE', 'CASE', 'ON', 'AND', 'HAVING', 'OR', 'SET','WITH','BY','PRIOR']
# 'IN' is recorded as its own operator unit linking the operands around it.
IN_KEYWORD = ['IN']
# These add an ORDER BY / GROUP BY operator unit; the following item is parsed as columns.
ORDER_KEYWORD = ['ORDER BY','GROUP BY']
# LIKE / BETWEEN / IS each add an operator unit linked to the unit added just before it.
LIKE_KEYWORD = ['LIKE']
# NOTE(review): not referenced by SQLParser; VALUES groups are detected through
# the sqlparse ``Values`` class instead.
VALUE_KEYWORD = ['VALUES']
BETWEEN_KEYWORD = ['BETWEEN']
IS_KEYWORD = ['IS']

# Upper-cased values that make a WHERE clause count as a real filter even
# though they are not Name tokens (see SQLParser.has_where).
# NOTE(review): the __main__ example filters on ``rownum``, which this entry
# does not match literally; confirm 'ROWNUMBER' is what was intended.
WHERE_EXCEPT = ['ROWNUMBER']


class SQLParser:
    """Parse one SQL statement into a graph of ``ParseUnit`` elements.

    The raw text is normalised by :meth:`sql_interpreter`, parsed with
    ``sqlparse`` and walked by :meth:`lu_parse`, which records every token in
    ``self.elements`` (a ``ParseUnitList``).

    Keyword arguments:
        exception: token values to treat as plain names rather than keywords.
            They are compared with the upper-cased text produced by
            :meth:`sql_interpreter`, so pass them upper-cased (e.g. 'SORT').
    """

    def __init__(self, sql: str, **kwargs) -> None:
        """Normalise, parse and walk ``sql``.

        :param sql: a single SQL statement; a trailing ';' is allowed.
        :raises Exception: if ``sql`` holds no statement or more than one.
        """
        self.exception_list = kwargs.get('exception', [])
        # Set by lu_parse when a WHERE clause references a name (or a WHERE_EXCEPT value).
        self.has_where = False
        self._origin_sql = sql
        sql = self.sql_interpreter(sql)
        tokens = sqlparse.parse(sql)
        if len(tokens) > 1:
            raise Exception("sql is not single")
        if not tokens:
            # Blank input parses to nothing; fail clearly instead of with IndexError.
            raise Exception("sql is empty")
        self._stmt = tokens[0]
        self.re_get_elements()
        self._sql_text = sql

    def re_get_elements(self, where_only: bool = False):
        """(Re)build ``self.elements`` from the parsed statement and return it.

        :param where_only: forwarded to lu_parse as ``add_opt=not where_only``.
            NOTE(review): ``add_opt`` is only passed on to recursive calls and
            does not change what gets recorded, so this flag currently has no
            visible effect.
        """
        self.elements = ParseUnitList()
        self.lu_parse(self._stmt, add_opt=(not where_only))
        self.elements.build_relation()
        return self.elements

    @property
    def tokens(self) -> Union[TokenList, Token]:
        """The parsed ``sqlparse`` statement."""
        return self._stmt

    def sql_statement(self):
        """Return the statement type reported by sqlparse, upper-cased and stripped."""
        return self._stmt.get_type().upper().strip()

    def _is_function_contain_keyword(self, function: Token, keyword_list: List[str]) -> Optional[str]:
        """If ``function`` is a Function whose name is in ``keyword_list``, return that name upper-cased."""
        if isinstance(function, Function):
            name = function.tokens[0].value.upper()
            if name in keyword_list:
                return name
        return None

    def get_fist_keyword(self, statement: Token) -> str:
        """Return the first keyword (upper-cased) met walking ``statement`` depth-first, or ''."""
        for token in self.token_walk(statement, True, False):
            if token.is_keyword:
                return token.value.upper()
        return ''

    def lu_parse(self, statement: TokenList, level: int = 0, t_idx: int = 3, **kwargs) -> None:
        """Recursively walk ``statement`` and record its tokens in ``self.elements``.

        :param statement: token group (or single token) to walk.
        :param level: sub-query nesting depth stored on every unit.
        :param t_idx: index into ``type_name`` giving the unit type of the next
            plain token; it is switched whenever a keyword is met.

        Keyword arguments (all optional):
            add_opt: only forwarded to recursive calls.
            parents: ids of the units new units are attached to.
            build_relation: rebuild relations before descending into nested groups.
            keyword: keyword currently in effect.
            is_where: True while inside a WHERE clause.
        """
        # NOTE(review): 'IN' (index 3, also the default t_idx) is not accepted by
        # ParseUnit.type; it is only safe for tokens routed to handlers that ignore
        # the type argument.
        type_name = ['COL', 'TAB', 'SUB', 'IN', 'STRUCT']
        function_keywords = COL_KEYWORD + TAB_KEYWORD + IN_KEYWORD + ORDER_KEYWORD

        add_opt = kwargs.get('add_opt', True)
        parents = kwargs.get('parents', [])
        build_relation = kwargs.get('build_relation', True)
        keyword_capture = kwargs.get('keyword', '')
        is_where = kwargs.get('is_where', False)
        # True right after ORDER BY / GROUP BY: the next token is descended into.
        order_by_loop = False
        if not isinstance(statement, TokenList):
            statement = [statement]
        for t in statement:
            v = self._is_function_contain_keyword(t, function_keywords)
            if t.value in self.exception_list:
                # User-listed words (e.g. a column called SORT) are plain names.
                t.is_keyword = False
                t.ttype = sqlparse.tokens.Name
                t = Identifier([t])
            if t.is_keyword or v is not None:
                keyword_capture = v if v is not None else t.normalized

                if keyword_capture in COL_KEYWORD:
                    t_idx = 0
                elif keyword_capture in TAB_KEYWORD or 'JOIN' in keyword_capture:
                    t_idx = 1
                elif keyword_capture in IN_KEYWORD:
                    t_idx = 3
                elif keyword_capture in ORDER_KEYWORD:
                    t_idx = 0
                    self.elements.add_order(tokens=t, key=keyword_capture, parents=parents, level=level,
                                            is_where=is_where)
                    order_by_loop = True
                    continue
                elif keyword_capture in BETWEEN_KEYWORD:
                    t_idx = 0
                    self.elements.add_between(tokens=t, key=keyword_capture, parents=parents, level=level,
                                              is_where=is_where)
                    continue
                elif keyword_capture in LIKE_KEYWORD:
                    t_idx = 0
                    self.elements.add_like(tokens=t, key=keyword_capture, parents=parents, level=level,
                                           is_where=is_where)
                    continue
                elif keyword_capture in IS_KEYWORD:
                    t_idx = 0
                    self.elements.add_is(tokens=t, key=keyword_capture, parents=parents, level=level, is_where=is_where)
                    continue
                else:
                    t_idx = 4

            # Whitespace and comments are not recorded.
            if t.is_whitespace or isinstance(t, Comment) or 'Comment' in str(t.ttype):
                continue

            # -- descend into operations / CASE / VALUES / the item after ORDER BY -- #
            # parents and is_where are forwarded so nested units keep their context.
            if isinstance(t, (Case, Operation, Values)) or order_by_loop:
                order_by_loop = False
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents, level=level,
                              keyword=keyword_capture, is_where=is_where)
                continue

            if isinstance(t, Where):
                # A WHERE clause only counts as a filter if it references a name.
                if any(str(tt.ttype) == 'Token.Name' or tt.value.upper() in WHERE_EXCEPT
                       for tt in self.token_walk(t, yield_current_token=False)):
                    self.has_where = True
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents, level=level,
                              is_where=True)
                continue

            if isinstance(t, IdentifierList):
                # Forward level and is_where so list items keep their nesting depth
                # and WHERE flag.
                self.lu_parse(t, t_idx=t_idx, build_relation=False, parents=parents, add_opt=add_opt,
                              keyword=keyword_capture, level=level, is_where=is_where)
                continue

            rest = self.elements.add(t, type_name[t_idx], parents=parents, key=keyword_capture,
                                     is_where=is_where, level=level)

            # 'IN' applies to one token only.
            if t_idx == 3:
                t_idx = 0
            # After "INTO <table>" the following items are columns.
            if keyword_capture == 'INTO' and \
                    self.elements.by_id[len(self.elements.by_id) - 1].type == 'TAB':
                t_idx = 0

            # -- nested groups (sub-queries, function arguments, ...) -- #
            if rest is None:
                continue
            for rest_v in rest:
                sub_parents = rest_v['parents']
                if build_relation:
                    self.elements.build_relation()
                for rest_t in rest_v['tokens']:
                    if isinstance(rest_t, Parenthesis):
                        # Only a parenthesised SELECT opens a deeper level.
                        first_keyword = self.get_fist_keyword(rest_t)
                        level_ = level + 1 if first_keyword == 'SELECT' else level
                        self.lu_parse(statement=rest_t, parents=sub_parents, t_idx=t_idx,
                                      add_opt=add_opt, build_relation=True, level=level_, is_where=is_where)

    def sql_reconstruct(self):
        """Alias of :meth:`reconstruct`, kept for backward compatibility."""
        return self.reconstruct()

    def display_elements(self) -> None:
        """Print every recorded unit."""
        for v in self.elements:
            print(v)

    # ---------#
    def get_table_name(self, alise_on = False) -> Union[Tuple[List[str],List[str]],List[str]]:
        """Return the names of the real tables referenced by the statement.

        Sub-query pseudo tables (names containing '(') and ``DUAL`` are skipped.

        :param alise_on: when True also return a parallel list of aliases (the
            table name itself when a table has no alias).
        """
        tab_names = []
        as_names = []
        for tab in self.elements.by_type['TAB']:
            # Exact match: a substring test would also drop tables such as INDIVIDUAL.
            if '(' in tab.name or tab.name == 'DUAL':
                continue
            tab_names.append(tab.name)
            as_names.append(tab.as_name if tab.as_name != 'DUMMY' else tab.name)
        if alise_on:
            return tab_names, as_names
        return tab_names

    def token_walk(self, token, topdown=True, yield_current_token=True):
        """Yield the tokens of ``token`` depth-first.

        :param token: a token, token list, or any iterable of tokens.
        :param topdown: when True, descend into every child that has ``tokens``;
            otherwise yield only the direct children.
        :param yield_current_token: also yield ``token`` itself and, on
            recursion, each group before its children.
        """
        if yield_current_token:
            yield token
        for tk in token:
            if topdown and hasattr(tk, "tokens"):
                yield from self.token_walk(tk, topdown, yield_current_token)
            else:
                yield tk

    def sql_interpreter(self, sql: str) -> str:
        """Normalise raw SQL text before parsing.

        Steps:
        - strip surrounding whitespace and trailing ';';
        - remove ``(+)`` markers, ``/* ... */`` comments and lines that start
          with ``--``;
        - collapse whitespace runs to one space;
        - upper-case every token.

        After ``INSERT ... INTO``, a ',' is inserted before the next '('
        (presumably so the table name and column list are parsed as separate
        items; confirm).

        NOTE(review): a ``--`` comment that follows code on the same line is
        not removed, and once newlines are collapsed it swallows the rest of
        the statement.  Comment markers inside string literals are not
        protected either.
        """
        # -- 1. clean up --#
        sql = sql.strip()
        # Drop trailing separators; text made only of separators is kept as is.
        sql = sql.rstrip('\n\t ;') or sql
        sql = sql.replace('(+)', '')
        sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
        sql = re.sub(r'(?m)^ *--.*\n?', '', sql)
        sql = re.sub(r'\s+', ' ', sql).strip()
        tokens = sqlparse.parse(sql)
        # -- 2. rebuild the upper-cased text --#
        pieces = []
        pre_keyword = None
        comma_pending = False  # set by "INSERT ... INTO", consumed by the next '('
        for t in self.token_walk(tokens, True, False):
            if t.value == '(' and comma_pending:
                pieces.append(',')
                comma_pending = False
            pieces.append(t.value.upper())
            if t.is_keyword:
                if t.value.upper() == 'INTO' and pre_keyword == 'INSERT':
                    comma_pending = True
                pre_keyword = t.value.upper()
        return ''.join(pieces)

    def reconstruct(self) -> str:
        """Rebuild formatted SQL text from ``self.elements``.

        - Units are visited in id order; SUB units are skipped.
        - Consecutive STRUCT units (other than parentheses) are collapsed into
          the highest-weighted one via ``Keywordstack``.  That unit is emitted
          before the next non-STRUCT unit, or at the end.
        - SELECT / FROM / WHERE / ORDER BY start a new line indented by the
          unit level.
        - ',', '(' and ')' are glued to the preceding text.
        """
        n_line_elem = ['SELECT', 'FROM', 'WHERE', 'ORDER BY']
        without_space = [',', '(', ')']
        out = ''

        def emit(text: str, indent: int) -> None:
            # Append one piece of output, applying the line-break and spacing rules.
            nonlocal out
            if text in n_line_elem:
                text = '\n' + '\t' * indent + text
            if text in without_space:
                # endswith() is safe on the still-empty output (out[-1] was not).
                if out.endswith(' '):
                    out = out[:-1]
                out += text
            else:
                out += text + ' '

        buffer = Keywordstack()
        for i in sorted(self.elements.by_id.keys()):
            unit = self.elements.by_id[i]
            if unit.type == 'SUB':
                continue
            value = unit.token.value
            if unit.type == 'STRUCT' and value not in [')', '(']:
                buffer.insert(unit)
                continue
            if not buffer.is_empty():
                pending = buffer.pop()
                buffer.reset()
                emit(pending.name, unit.level)
            emit(value, unit.level)
        # Flush keywords that follow the last non-STRUCT unit (such as a final DESC).
        if not buffer.is_empty():
            pending = buffer.pop()
            buffer.reset()
            emit(pending.name, pending.level)
        return out


if __name__ == "__main__":
    # Manual smoke test: parse one statement, then print the parsed elements,
    # the reconstructed SQL, whether a real WHERE filter was detected and the
    # table names found.
    #
    # Another statement worth trying (back-quoted aliases, ORDER BY):
    #     select `TYPE` as `type`, FILE_PATH as filePath,
    #            DISPLAY_NAME as `name`, SEQUENCE as sequence
    #     from bid_document
    #     where BID_ID = 3 and IS_VALID = 1 and rownum > 10000
    #     order by SEQUENCE
    sql_text = '''
    select * from tab where rownum > 1000
    '''
    # 'SORT' is parsed as an ordinary name instead of a keyword.
    sp = SQLParser(sql_text, exception=['SORT'])

    sp.display_elements()
    out = sp.reconstruct()
    print(out)
    print(sp.has_where)
    print('table_names :', sp.get_table_name())
    print(sp.has_where)
    print('table_names :', sp.get_table_name())

units.py:

# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
# -*- coding: utf-8 -*-

from sqlparse.sql import Statement, Comment, Where, Identifier, IdentifierList, Parenthesis, Function, \
    Comparison, Operation, Token, TokenList, Values
from typing import Union, List, Tuple, Optional, Set


class ParseUnit:
    """A single node of the parsed-SQL graph.

    A unit describes one table, column, operator, function, sub-query,
    structural token or literal value (see ``type``).  Text set through the
    ``name``, ``as_name``, ``from_name`` and ``keyword`` properties is stored
    upper-cased.  ``parent`` and ``edges`` are sets of ids of related units in
    the owning ``ParseUnitList``.
    """

    # Values accepted by the ``type`` setter.
    VALID_TYPES = ('COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE')

    def __init__(self):
        self.id = None  # assigned by ParseUnitList when the unit is inserted
        self._name = None  # sql code name
        self._as_name = None  # alias ("as what name")
        self._from_name = None  # qualifier the name was taken from ("from where")
        # TAB-table, COL-column, SUB-subquery, OPT- >,<,=.., FUNC-MAX,SUM..,
        # STRUCT-other structural tokens, VALUE-literal values
        self._type = None
        self._keyword = None  # keyword in effect when the unit was parsed
        self._in_statement = 'OTHER'  # 'WHERE' or 'OTHER'
        self._parent = set()  # ids of parent units
        self._edges = set()  # ids of linked units
        self._level = 0  # sub-query nesting depth
        self.token = None  # sqlparse token the unit was built from

    @property
    def in_statement(self) -> str:
        return self._in_statement

    @in_statement.setter
    def in_statement(self, state: str):
        # 'ORTHER' (the misspelling the check used to require) is still accepted.
        if state == 'ORTHER':
            state = 'OTHER'
        if state not in ('WHERE', 'OTHER'):
            raise ValueError('type must be either one of following [WHERE, OTHER]')
        self._in_statement = state

    @property
    def level(self) -> int:
        return self._level

    @level.setter
    def level(self, level: int):
        self._level = level

    @property
    def keyword(self) -> str:
        return self._keyword

    @keyword.setter
    def keyword(self, key: str):
        self._keyword = key.upper()

    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, name: Optional[str]):
        # Non-string values (e.g. None) are stored unchanged.
        self._name = name.upper() if isinstance(name, str) else name

    @property
    def as_name(self) -> str:
        return self._as_name

    @as_name.setter
    def as_name(self, as_name: str):
        self._as_name = as_name.upper()

    @property
    def from_name(self) -> str:
        return self._from_name

    @from_name.setter
    def from_name(self, from_name: str):
        self._from_name = from_name.upper()

    @property
    def parent(self) -> set:
        return self._parent

    @parent.setter
    def parent(self, parent: Set['ParseUnit']):
        self._parent = parent

    @property
    def type(self) -> str:
        return self._type

    @type.setter
    def type(self, type_: str):
        # Any letter case is accepted; the stored value is always upper-case.
        normalized = type_.upper() if isinstance(type_, str) else type_
        if normalized not in self.VALID_TYPES:
            raise ValueError('type must be either one of following [COL, TAB, SUB, OPT, FUNC, STRUCT, VALUE]')
        self._type = normalized

    @property
    def edges(self) -> set:
        return self._edges

    @edges.setter
    def edges(self, edges: Set['ParseUnit']):
        self._edges = edges

    def overwrite(self, unit: 'ParseUnit'):
        """Copy every attribute that ``unit`` has set (non-None / non-empty) onto self.

        ``parent`` and ``edges`` are shared with ``unit``, not copied.
        """
        if unit.name is not None:
            self._name = unit.name
        if unit.as_name is not None:
            self._as_name = unit.as_name
        if unit.from_name is not None:
            self._from_name = unit.from_name
        if unit.parent is not None:
            self._parent = unit.parent
        if unit.type is not None:
            self._type = unit.type
        # Only take edges the source actually has (the test used to be inverted).
        if unit.edges:
            self._edges = unit.edges

    def inherit(self, unit: 'ParseUnit', update_edges: bool = False):
        """Take name, alias, type and (unless 'DUMMY') qualifier from ``unit``.

        :param update_edges: also add ``unit.id`` to this unit's edges.
        """
        self._name = unit.name
        self._as_name = unit.as_name
        if unit.from_name != 'DUMMY':
            self._from_name = unit.from_name
        self._type = unit.type
        if update_edges:
            self._edges.add(unit.id)

    def show(self) -> str:
        """Render the unit as ``[FROM.]NAME[ as ALIAS]``.

        A qualifier or alias that is None or the 'DUMMY' placeholder is omitted.
        """
        out = ''
        # Value comparison: ``is not 'DUMMY'`` tested identity and ``and not None``
        # was always true.
        if self._from_name not in ('DUMMY', None):
            out += self._from_name + '.'
        if self._name is not None:
            out += self._name
        if self._as_name not in ('DUMMY', None):
            out += ' as ' + self._as_name
        return out

    def add_parents(self, parents: Union[List[int], Set[int]]) -> None:
        """Add every id in ``parents`` to this unit's parent set."""
        self._parent.update(parents)

    def __repr__(self):
        out = '%s\n' % str(self.id)
        out += '\ttype:%s\n' % self.type
        out += '\tname:%s\n' % self.name
        out += '\tkeyword:%s\n' % self.keyword
        out += '\tstatement:%s\n' % self.in_statement
        out += '\tlevel:%s\n' % self.level
        out += '\tas_name:%s\n' % self.as_name
        out += '\tfrom' + (' tab ' if self.type == 'COL' else '') + ':%s\n' % self.from_name
        out += '\tparent:%s\n' % str(self.parent)
        out += '\tedges:%s\n' % str(self.edges)

        return out


class ParseUnitList:
    def __init__(self) -> None:
        # Units grouped by their ParseUnit.type, in insertion order.
        self.by_type = {kind: [] for kind in ('COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE')}
        # Every unit keyed by its integer id (the graph's node table).
        self.by_id = {}
        # Turned on by _add_value; build_relation then also resolves SUB units.
        self._allow_sub_has_table = False

    def __insert(self, unit: ParseUnit) -> int:
        """Give ``unit`` the next id (the current unit count), register it and return that id."""
        new_id = len(self.by_id)
        unit.id = new_id
        self.by_type[unit.type].append(unit)
        self.by_id[new_id] = unit
        return new_id

    def __update_by_type(self) -> None:
        """Re-sync ``by_id`` from the units held in ``by_type``.

        Every bucket is visited, including 'VALUE', which was previously left
        out.  If one id appears in several buckets, the bucket visited last
        wins.
        """
        for key in ['SUB', 'TAB', 'OPT', 'FUNC', 'COL', 'STRUCT', 'VALUE']:
            for unit in self.by_type[key]:
                self.by_id[unit.id] = unit

    def __update_by_id(self):
        """Rebuild ``by_type`` from scratch out of the units in ``by_id``."""
        self.by_type = {kind: [] for kind in ('COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE')}
        for unit in self.by_id.values():
            self.by_type[unit.type].append(unit)

    ########################################
    #           add  function              #
    ########################################

    # ----------- add by token type -----------#

    def _add_Identifier(self, tokens: Token, type: str, key: str, level: int, is_where: bool,
                        parents: List[int] = None) -> Tuple[int, Union[Token, TokenList]]:
        """Add an identifier-like token (``[qualifier.]name [alias]`` or a leaf token) as one unit.

        :param tokens: an ``Identifier`` group or a single leaf token.
        :param type: unit type to use; replaced by 'SUB' when the token text
            contains '(' (and is not just '(').
        :param key: keyword currently in effect.
        :return: ``(id, abnormal)``.  ``abnormal`` is the last child group
            (``ttype is None``) that is not an ``Identifier``, e.g. a function
            call or sub-query the caller still has to descend into, or None.
        """
        out = ParseUnit()
        if '(' in tokens.value and tokens.value != '(':
            out.type = 'SUB'
        else:
            out.type = type
        out.keyword = key
        out.level = level
        out.token = tokens
        if is_where:
            out.in_statement = 'WHERE'
        if parents:
            out.add_parents(parents)

        # Name tokens seen while dot_flag is odd (before a '.') go to from_name;
        # the one right after a '.' goes to name.  Without any '.', the lone name
        # is moved from from_name to name and from_name becomes 'DUMMY'.
        dot_flag = 1
        abnormal = None
        try:
            for t in tokens:
                if str(t.ttype).upper() == 'TOKEN.PUNCTUATION' and t.value == '.':
                    dot_flag += 1
                    continue
                if str(t.ttype).upper() == 'TOKEN.NAME':
                    if dot_flag % 2 == 0:
                        out.name = t.value
                        dot_flag += 1
                    else:
                        out.from_name = t.value
                if t.ttype is None:
                    # Child group: an alias Identifier, or something to descend into.
                    out.as_name = t.value
                    if not isinstance(t, Identifier):
                        abnormal = t
            if dot_flag <= 1:
                out.name = out.from_name
                out.from_name = 'DUMMY'
        except Exception:
            # Best effort (e.g. a leaf token is not iterable): use the raw text as name.
            out.name = tokens.value

        if out.as_name is None:
            out.as_name = 'DUMMY'

        # -- patch: fall back to the nested group's text, else the alias --#
        if out.name is None:
            out.name = abnormal.value if abnormal is not None else out.as_name

        if key in ('ORDER BY', 'GROUP BY'):
            # Hang the unit under the most recent matching ORDER BY / GROUP BY operator;
            # if none exists no parent is added (previously unit 0 or an unbound name).
            owner = next((uid for uid in range(len(self.by_id) - 1, -1, -1)
                          if self.by_id[uid].type == 'OPT' and self.by_id[uid].name == key), None)
            if owner is not None:
                out.parent.add(owner)

        if key == 'LIKE' and self.by_id:
            # Attach to the unit added just before (normally the LIKE operator).
            out.add_parents([len(self.by_id) - 1])
        return self.__insert(out), abnormal

    def _add_Comparison(self, tokens: Comparison, type: str, key: str, level: int, is_where: bool,
                        parents: List[int] = None) \
            -> Optional[List[dict]]:
        """Add a comparison (``a = b``, ``a > b``, ...) as an OPT unit between its operands.

        Units are inserted in source order: the left operand's unit(s), then
        the operator, then the right operand's unit(s).  Both operands are
        given ``expect_id`` (the id the operator is expected to receive) as
        their parent.

        :return: nested token groups (see ``add``) collected from both
            operands, or None / an empty list when there are none.
        """
        # -- get opt unit --#
        # Operator text: the last child whose ttype is Operator.Comparison.
        opt = None
        for t in tokens:
            if str(t.ttype).upper() == 'TOKEN.OPERATOR.COMPARISON':
                opt = t.value
        unit = ParseUnit()
        unit.name = opt
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        # The unit's token is the second non-whitespace child (normally the operator).
        count = 0
        for t in tokens:
            if not t.is_whitespace:
                count += 1
            if count == 2:
                unit.token = t
                break
        # Assumes the left operand adds exactly one unit before the operator.
        # NOTE(review): `add` can insert more than one unit for the left side (an
        # Identifier wrapping a Function is expanded into two), in which case this
        # id is wrong and the operands point at the wrong parent — confirm.
        expect_id = len(self.by_id) + 1
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)

        # -- left unit -- #
        # From here on `parents` means the operands' parent (the operator).
        parents = [expect_id]
        parents_token_left = self.add(tokens=tokens.left, type=type, key=key, level=level, parents=parents,
                                      is_where=is_where)

        self.__insert(unit)

        parents_token_right = self.add(tokens=tokens.right, type=type, key=key, level=level, parents=parents,
                                       is_where=is_where)

        # unit.edges.add(left_v)
        # unit.edges.add(right_v)

        if parents_token_left and parents_token_right:
            return parents_token_left + parents_token_right
        elif parents_token_left:
            return parents_token_left
        else:
            return parents_token_right

    def _add_Operation(self, tokens: Operation, type: str, key: str, level: int, is_where: bool,
                       parents: List[int] = None):
        """Add an OPT unit for an ``Operation`` group, then every child token under it.

        Each child token (whitespace included) is passed to :meth:`add` with
        the new operator's id as its only parent.
        """
        operator_unit = ParseUnit()
        operator_unit.name = tokens.value
        operator_unit.type = 'OPT'
        operator_unit.keyword = key
        operator_unit.level = level
        operator_unit.token = tokens
        if is_where:
            operator_unit.in_statement = 'WHERE'
        if parents:
            operator_unit.add_parents(parents)
        operator_id = self.__insert(operator_unit)
        for child in tokens.tokens:
            self.add(tokens=child, type=type, key=key, level=level, parents=[operator_id],
                     is_where=is_where)

    def _add_Function(self, tokens: Function, key: str, level: int, is_where: bool, parents: List[int] = None) \
            -> Tuple[int, Optional[list]]:
        """Add a FUNC unit named after the function and return ``(id, remaining_tokens)``.

        ``remaining_tokens`` are the function's children after its name
        (``tokens.tokens[1:]``), left for the caller to descend into.
        """
        head = tokens.tokens[0]
        func_unit = ParseUnit()
        func_unit.name = head.value
        func_unit.type = 'FUNC'
        func_unit.keyword = key
        func_unit.level = level
        func_unit.token = head
        if is_where:
            func_unit.in_statement = 'WHERE'
        if parents:
            func_unit.add_parents(parents)
        func_id = self.__insert(func_unit)
        return func_id, tokens.tokens[1:]

    def _add_Parenthesis(self, tokens: Parenthesis, key: str, level: int, is_where: bool, parents: List[int] = None) \
            -> Tuple[int, Parenthesis]:
        """Add a SUB unit for a parenthesised group and return ``(id, tokens)``.

        The group itself is handed back so the caller can descend into it.
        Qualifier and alias are set to the 'DUMMY' placeholder.
        """
        sub_unit = ParseUnit()
        sub_unit.name = tokens.value
        sub_unit.type = 'SUB'
        sub_unit.keyword = key
        sub_unit.level = level
        sub_unit.from_name = sub_unit.as_name = 'DUMMY'
        sub_unit.token = tokens
        if is_where:
            sub_unit.in_statement = 'WHERE'
        if parents:
            sub_unit.add_parents(parents)
        return self.__insert(sub_unit), tokens

    def add(self, tokens: Union[Token, TokenList], type: str, is_where: bool, key: str, level: int,
            parents: List[int] = None) -> Optional[List[dict]]:
        """Add ``tokens`` to the graph, dispatching on its sqlparse class.

        :param type: unit type ('COL', 'TAB', ...) for identifier-like tokens.
        :param key: keyword currently in effect.
        :return: None, or a list of ``{'parents': [...], 'tokens': [...]}``
            dicts.  Each dict is a nested group (sub-query, function arguments,
            ...) that the caller still has to parse, plus the parent ids to use.
        """
        if isinstance(tokens, Identifier):
            unit_id, abnormal = self._add_Identifier(tokens=tokens, type=type, parents=parents, key=key, level=level,
                                                     is_where=is_where)
            if abnormal is None:
                return None
            if isinstance(abnormal, Function):
                return self.add(tokens=abnormal, type=type, parents=[unit_id], key=key, level=level,
                                is_where=is_where)
            return [{'parents': [unit_id], 'tokens': [abnormal]}]
        elif isinstance(tokens, Comparison):
            return self._add_Comparison(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where,
                                        level=level)
        elif isinstance(tokens, Function):
            unit_id, token_list = self._add_Function(tokens=tokens, parents=parents, key=key, level=level,
                                                     is_where=is_where)
            return [{'parents': [unit_id], 'tokens': token_list}]
        elif isinstance(tokens, Parenthesis):
            unit_id, token = self._add_Parenthesis(tokens=tokens, parents=parents, key=key, level=level,
                                                   is_where=is_where)
            return [{'parents': [unit_id], 'tokens': [token]}]
        elif isinstance(tokens, Values):
            return self._add_value(tokens=tokens, level=level, parents=parents, is_where=is_where)
        elif isinstance(tokens, Operation):
            self._add_Operation(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where, level=level)
        elif tokens.value.upper() == 'IN':
            self._add_In(tokens=tokens, is_where=is_where, key=key, level=level, parents=parents)
        else:
            # Literals become VALUE units, anything else STRUCT.  Token groups have
            # ttype None; treat them as STRUCT instead of failing on ttype[0].
            ttype = tokens.ttype
            unit_type = 'VALUE' if ttype and str(ttype[0]) in ('Literal', 'Number') else 'STRUCT'
            unit_id, nested = self._add_Identifier(tokens=tokens, type=unit_type, parents=parents, key=key,
                                                   level=level, is_where=is_where)
            if nested is not None:
                self.add(tokens=nested, type=unit_type, is_where=is_where, key=key, level=level, parents=[unit_id])

        return None

    # ----------- add by keywords ----------- #
    def _add_In(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> None:
        """Add an ``IN`` operator unit linked to its neighbours by position.

        The unit added just before is taken as the left operand and the id
        right after this one as the right operand.  Both go into the
        operator's edges, and the left unit also gets the operator as a parent.
        """
        in_id = len(self.by_id)
        in_unit = ParseUnit()
        in_unit.name = 'IN'
        in_unit.type = 'OPT'
        in_unit.edges = {in_id - 1, in_id + 1}
        in_unit.keyword = key
        in_unit.level = level
        in_unit.token = tokens
        if is_where:
            in_unit.in_statement = 'WHERE'
        if parents:
            in_unit.add_parents(parents)
        self.__insert(in_unit)
        self.by_id[in_id - 1].parent.add(in_id)

    def add_order(self, tokens: Token, key: str, level: int, is_where: bool,
                  parents: List[int] = None) -> Optional[List[dict]]:
        """Add an ORDER BY / GROUP BY operator unit; always returns None.

        The unit's edge points at the id one past its own, i.e. the next unit
        added after it.
        """
        order_id = len(self.by_id)
        order_unit = ParseUnit()
        order_unit.name = tokens.value
        order_unit.type = 'OPT'
        order_unit.keyword = key
        order_unit.level = level
        order_unit.token = tokens
        if is_where:
            order_unit.in_statement = 'WHERE'
        if parents:
            order_unit.add_parents(parents)
        order_unit.edges.add(order_id + 1)
        self.__insert(order_unit)
        return None

    def add_like(self, tokens: Token, key: str, level: int, is_where: bool,
                 parents: List[int] = None) -> Optional[List[dict]]:
        """Add a LIKE operator unit whose edge points at the unit added just before it.

        Always returns None.
        """
        like_id = len(self.by_id)
        like_unit = ParseUnit()
        like_unit.name = tokens.value
        like_unit.type = 'OPT'
        like_unit.keyword = key
        like_unit.level = level
        like_unit.token = tokens
        if is_where:
            like_unit.in_statement = 'WHERE'
        if parents:
            like_unit.add_parents(parents)
        like_unit.edges.add(like_id - 1)
        self.__insert(like_unit)
        return None

    def add_between(self, tokens: Token, key: str, level: int, is_where: bool,
                    parents: List[int] = None) -> Optional[List[dict]]:
        """Add a BETWEEN operator unit whose edge points at the unit added just before it.

        Always returns None.
        """
        between_id = len(self.by_id)
        between_unit = ParseUnit()
        between_unit.name = 'BETWEEN'
        between_unit.type = 'OPT'
        between_unit.keyword = key
        between_unit.level = level
        between_unit.token = tokens
        if is_where:
            between_unit.in_statement = 'WHERE'
        if parents:
            between_unit.add_parents(parents)
        between_unit.edges.add(between_id - 1)
        self.__insert(between_unit)
        return None

    def _add_value(self, tokens: Token, level: int, is_where: bool, parents: List[int] = None) -> Optional[List[dict]]:
        """Add a ``VALUES`` operator unit and one SUB unit per parenthesised row.

        Side effect: sets ``_allow_sub_has_table`` so that ``build_relation``
        also resolves parents of SUB units.

        :return: one ``{'parents': [row_id], 'tokens': [row]}`` dict per
            parenthesised row, for the caller to descend into.
        """
        self._allow_sub_has_table = True
        col_id = len(self.by_id) - 1  # unit added just before VALUES
        values_id = col_id + 1  # id the VALUES unit receives below
        unit = ParseUnit()
        unit.name = 'VALUES'
        unit.type = 'OPT'
        unit.keyword = 'VALUES'
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents:
            unit.add_parents(parents)
        unit.edges = {col_id, col_id + 2}
        self.__insert(unit)

        out = []
        for t in tokens.tokens[1:]:
            if isinstance(t, Parenthesis):
                # _add_Parenthesis already inserts the row unit and returns its id;
                # passing that int to __insert again raised AttributeError.
                row_id, row = self._add_Parenthesis(tokens=t, key='VALUES', level=level, parents=[values_id],
                                                    is_where=is_where)
                out.append({'parents': [row_id], 'tokens': [row]})
        return out

    def add_is(self, tokens: Token, key: str, level: int, is_where: bool,
               parents: List[int] = None) -> Optional[List[dict]]:
        """Add an IS operator unit whose edge points at the unit added just before it.

        Always returns None.
        """
        is_id = len(self.by_id)
        is_unit = ParseUnit()
        is_unit.name = tokens.value
        is_unit.type = 'OPT'
        is_unit.keyword = key
        is_unit.level = level
        is_unit.token = tokens
        if is_where:
            is_unit.in_statement = 'WHERE'
        if parents:
            is_unit.add_parents(parents)
        is_unit.edges.add(is_id - 1)
        self.__insert(is_unit)
        return None

    #########################################

    def __iter__(self):
        return iter(self.by_id.values())

    #########################################
    #          build relation function      #
    #########################################

    def build_relation(self):
        """Link the parse units into a graph through their ``parent``/``edges`` sets.

        Pass 1 visits the 'SUB', 'TAB' and 'COL' units, in that order.
          * Each 'TAB' unit is indexed under both its ``as_name`` and its ``name``.
          * Each 'COL' unit, and each 'SUB' unit when
            ``self._allow_sub_has_table`` is set, is resolved as follows:
              - If ``from_name`` is not 'DUMMY', it gets the indexed tables with
                that name as parents. An unknown name gives no parents.
              - Otherwise it gets every 'TAB' unit on its own level as parents.
                When exactly one such table exists, ``from_name`` is filled in
                with that table's alias (or its name when the alias is 'DUMMY').

        Pass 2 visits every unit in ``by_id`` order.
          * Each unit becomes a parent of every id in its ``edges``.
          * Units after a BETWEEN node on the same level get that node as a
            parent (see the review note below).

        Pass 3 mirrors every parent link back into the parent's ``edges``.

        Side effects: mutates units in place, resets
        ``self._allow_sub_has_table`` to False, and calls the private
        ``__update_by_type`` / ``__update_by_id`` helpers (presumably to
        re-sync those indexes).

        :raises KeyError: if a ``parent`` or ``edges`` entry refers to an id
            that is not a key of ``self.by_id``.
        """
        symbol_idx = dict()  # {TAB as_name or name: [TAB ids]}
        # {id: child ids}. Never read afterwards; its lookups still raise
        # KeyError for a parent id that is missing from by_id.
        idx_edges = {key: set() for key in self.by_id}

        check_keys = ['COL', 'SUB'] if self._allow_sub_has_table else ['COL']

        # --- pass 1: table / column relations --- #
        # NOTE(review): 'SUB' units are visited before any 'TAB' symbol is
        # registered, so a SUB with an explicit from_name never resolves to a
        # table here — confirm this ordering is intended.
        for key in ['SUB', 'TAB', 'COL']:
            for unit in self.by_type[key]:
                key_i = unit.type
                for p in unit.parent:
                    idx_edges[p].add(unit.id)

                if key_i == 'TAB':
                    symbol_idx.setdefault(unit.as_name, []).append(unit.id)
                    symbol_idx.setdefault(unit.name, []).append(unit.id)

                if key_i not in check_keys:
                    continue
                if unit.from_name != 'DUMMY':
                    # An unknown qualifier simply yields no parents.
                    for parent in symbol_idx.get(unit.from_name, []):
                        unit.parent.add(parent)
                        idx_edges[parent].add(unit.id)
                else:
                    level_tabs = self.add_all_parents(unit.level)
                    if len(level_tabs) == 1:
                        only_tab = self.by_id[next(iter(level_tabs))]
                        unit.from_name = only_tab.as_name if only_tab.as_name != 'DUMMY' else only_tab.name
                    unit.add_parents(level_tabs)
                    for p in unit.parent:
                        idx_edges[p].add(unit.id)
        self.__update_by_type()

        # --- pass 2: parents from edges, plus BETWEEN operands --- #
        between_count = None  # same-level units seen since the active BETWEEN
        blevel = None         # level of the active BETWEEN node
        b_id = None           # id of the active BETWEEN node
        for uid, unit in self.by_id.items():
            if between_count is not None:
                if blevel == unit.level:
                    between_count += 1
                    unit.parent.add(b_id)
                # NOTE(review): the reset tests a nesting level (blevel == 3),
                # and between_count is never consulted. As written, a BETWEEN
                # on a level other than 3 becomes a parent of every later unit
                # on its level until another BETWEEN takes over. The intent was
                # probably a check on between_count — confirm before changing.
                if blevel == 3:
                    between_count = None
                    blevel = None
                    b_id = None
            if unit.name == 'BETWEEN':
                between_count = 0
                blevel = unit.level
                b_id = unit.id

            for ed in unit.edges:
                self.by_id[ed].parent.add(uid)

        # --- pass 3: mirror parent links into edges --- #
        for uid, unit in self.by_id.items():
            for pa in unit.parent:
                self.by_id[pa].edges.add(uid)

        self._allow_sub_has_table = False
        self.__update_by_id()

    def add_all_parents(self, level: int) -> set:
        """Return the ids of all 'TAB' units whose ``level`` equals *level*.

        Despite its name, this method only collects ids; it does not modify
        any unit.

        :param level: nesting level to match.
        :return: a new set of unit ids; empty when no table sits on that level.
        """
        # Annotated with the builtin ``set``: ``typing.Set`` is not imported in
        # this module, and an unresolved annotation raises NameError when the
        # class body executes.
        return {unit.id for unit in self.by_type['TAB'] if unit.level == level}

    def build_relation_by_tab_info(self):
        """Unimplemented stub: does nothing and returns None."""

    ############################################
    #              graph search                #
    ############################################
    def find_root(self, graph: 'ParseUnit', col_only: bool = False) -> Optional[List[int]]:
        """Breadth-first search from *graph* along ``edges``.

        :param graph: start unit; only its ``id`` is used.
        :param col_only: when True, collect every reachable 'COL' unit whose
            name contains no '('. When False, collect every reachable unit that
            has no outgoing ``edges``.
        :return: ids of the collected units in BFS visiting order. The start
            unit is included when it matches. Never returns None.
        """
        from collections import deque

        root = []
        visited = set()  # O(1) membership instead of re-scanning a list
        queue = deque([graph.id])
        while queue:
            v = queue.popleft()  # O(1), unlike list.pop(0)
            # Non-int entries are skipped without being marked visited.
            if v in visited or type(v) is not int:
                continue
            visited.add(v)
            unit = self.by_id[v]
            if col_only:
                if unit.type == 'COL' and '(' not in unit.name:
                    root.append(unit.id)
            elif not unit.edges:
                root.append(unit.id)
            queue.extend(unit.edges)
        return root

    def find_tab(self, colum: 'ParseUnit', tab_only: bool = False) -> Optional[List[int]]:
        """Breadth-first search from *colum* upward along ``parent`` links.

        :param colum: start unit; only its ``id`` is used.
        :param tab_only: when True, collect every reachable 'TAB' unit. When
            False, collect every reachable unit whose ``parent`` set is empty.
        :return: ids of the collected units in BFS order, without duplicates.
            Never returns None.
        """
        from collections import deque

        tabs = []
        visited = set()  # O(1) membership instead of re-scanning a list
        queue = deque([colum.id])
        while queue:
            v = queue.popleft()  # O(1), unlike list.pop(0)
            if v in visited:
                continue
            visited.add(v)
            unit = self.by_id[v]
            matched = unit.type == 'TAB' if tab_only else not unit.parent
            if matched and unit.id not in tabs:
                tabs.append(unit.id)
            queue.extend(unit.parent)
        return tabs

    ############################
    #        remove node       #
    ############################

    def remove(self, id_list: List[int]):
        """Delete the units in *id_list* together with their removal trunks.

        ``_get_remove_trunk`` computes the ids to delete for each requested
        unit, and the union of those trunks is removed from ``self.by_id``.
        Trunks of different ids can overlap. Each id is therefore deleted only
        once, because a second ``del`` of the same key would raise KeyError.

        NOTE(review): only ``self.by_id`` is pruned. ``parent``/``edges``
        references held by surviving units, and other indexes such as
        ``by_type``, are not touched here.

        :param id_list: ids of the units to remove.
        :raises KeyError: if an id in *id_list* is not a key of ``self.by_id``.
        """
        doomed = []
        for uid in id_list:
            doomed.extend(self._get_remove_trunk(self.by_id[uid]))

        # dict.fromkeys de-duplicates while keeping first-seen order.
        for uid in dict.fromkeys(doomed):
            del self.by_id[uid]

    def _get_remove_trunk(self, unit: ParseUnit) -> List[int]:
        id_list = []
        target_level = unit.level
        path = []
        q = [unit.id]
        while q:
            v = q.pop(0)
            if not v in path:
                if type(v) == int:
                    path = path + [v]
                    # --- #
                    units_ = self.by_id[v]
                    c_level = units_.level
                    if units_.type != 'TAB' and units_.type != 'SUB':
                        id_list.append(units_.id)
                        q = q + list(units_.parent) + list(units_.edges)
                    else:
                        if c_level > target_level:
                            id_list.append(units_.id)
        return id_list


class SQLGrammarError(Exception):
    """Exception type reserved for SQL grammar errors found while parsing."""

You may also like

Reposted from www.cnblogs.com/applejuice/p/11672393.html