keywordDict.py:
# -*- coding:utf-8 -*-
# CREATED BY: bohuai jiang
# CREATED ON: 2019/9/18
# LAST MODIFIED ON:
# AIM: sort key words
from service.sql_parser_graph.units import ParseUnit

# Priority of keyword units competing in Keywordstack; a name missing from
# this table scores -1.
KEY_WEIGTH = {'SELECT': 10, 'INSERT': 10, 'WHERE': 8, 'FROM': 9, 'VALUES': 8, 'AND': 5, ',': 0}


class Keywordstack:
    """Keep the highest-weighted unit among those inserted since the last reset.

    Despite its name this is not a LIFO stack: ``pop`` returns the best unit
    without removing it, and only ``reset`` clears the state.
    """

    def __init__(self):
        self._clear()

    def _clear(self) -> None:
        """Forget the remembered unit and the insert counter."""
        self.value = None
        self.weight = -float('inf')
        self.length = 0

    def insert(self, value: ParseUnit) -> None:
        """Count *value* and remember it if it strictly outweighs the current best.

        The weight is looked up by the upper-cased ``value.name`` in KEY_WEIGTH;
        on a tie the earlier unit is kept.
        """
        self.length += 1
        score = KEY_WEIGTH.get(value.name.upper(), -1)
        if score > self.weight:
            self.value, self.weight = value, score

    def pop(self):
        """Return the best unit seen so far (``None`` if nothing was inserted)."""
        return self.value

    def reset(self):
        """Return to the freshly constructed state."""
        self._clear()

    def is_empty(self):
        """True when nothing has been inserted since construction or the last reset."""
        return self.length <= 0
SQLParser.py:
# -*- coding: utf-8 -*-
# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
import sqlparse
from sqlparse.sql import Where, IdentifierList, Identifier, TokenList, Token, Parenthesis, Comment, Case, Operation, \
    Function, Values
# NOTE(review): "KeywrodDict" looks like a transposition of "keywordDict"; confirm
# it matches the module file name on disk.
from service.sql_parser_graph.KeywrodDict import Keywordstack
from service.sql_parser_graph.units import ParseUnitList
from typing import Union, List, Optional, Tuple
import re

# Keywords after which plain tokens are recorded as tables ('TAB').
TAB_KEYWORD = ['FROM', 'LEFT JOIN', 'UPDATE', 'EXISTS', 'INNER JOIN', 'OUTER JOIN', 'JOIN', 'RIGHT JOIN', 'INTO']
# Keywords after which plain tokens are recorded as columns ('COL').
COL_KEYWORD = ['INSERT', 'SELECT', 'WHERE', 'CASE', 'ON', 'AND', 'HAVING', 'OR', 'SET', 'WITH', 'BY', 'PRIOR']
IN_KEYWORD = ['IN']
ORDER_KEYWORD = ['ORDER BY', 'GROUP BY']
LIKE_KEYWORD = ['LIKE']
VALUE_KEYWORD = ['VALUES']
BETWEEN_KEYWORD = ['BETWEEN']
IS_KEYWORD = ['IS']
# Besides Token.Name tokens, these values make a WHERE clause count for has_where.
WHERE_EXCEPT = ['ROWNUMBER']


class SQLParser:
    """Parse one SQL statement into a ParseUnitList graph.

    After construction ``elements`` holds the parsed units and ``has_where``
    tells whether a WHERE clause containing at least one name token (or a
    value listed in WHERE_EXCEPT) was found.
    """

    def __init__(self, sql: str, **kwargs) -> None:
        """
        :param sql: SQL text containing exactly one statement.
        :param kwargs: ``exception`` -- token values to treat as names instead of
            keywords; compared against the text after ``sql_interpreter`` has
            upper-cased it.
        :raises Exception: if the text is empty or contains more than one statement.
        """
        self.exception_list = kwargs.get('exception', [])
        self.has_where = False
        self._origin_sql = sql
        sql = self.sql_interpreter(sql)
        tokens = sqlparse.parse(sql)
        if len(tokens) > 1:
            raise Exception("sql is not single")
        if not tokens:
            # Fail with a clear message instead of an IndexError on tokens[0].
            raise Exception("sql is empty")
        self._stmt = tokens[0]
        self.re_get_elements()
        self._sql_text = sql

    def re_get_elements(self, where_only: bool = False):
        """Rebuild ``self.elements`` from the parsed statement and return it.

        :param where_only: forwarded to ``lu_parse`` as ``add_opt=not where_only``.
        """
        self.elements = ParseUnitList()
        self.lu_parse(self._stmt, add_opt=(not where_only))
        self.elements.build_relation()
        return self.elements

    @property
    def tokens(self) -> Union[TokenList, Token]:
        """The sqlparse statement being analysed."""
        return self._stmt

    def sql_statement(self):
        """Return the statement type reported by sqlparse (e.g. 'SELECT'), upper-cased.

        :return: statement type string
        """
        return self._stmt.get_type().upper().strip()

    def _is_function_contain_keyword(self, function: Token, keyword_list: List[str]) -> Optional[str]:
        """If *function* is a Function whose first token is in *keyword_list*, return that keyword."""
        if isinstance(function, Function):
            if function.tokens[0].value.upper() in keyword_list:
                return function.tokens[0].value.upper()
        return None

    def get_fist_keyword(self, statement: Token) -> str:
        """Return the first keyword (upper-cased) found in *statement*, or '' if none."""
        for token in self.token_walk(statement, True, False):
            if token.is_keyword:
                return token.value.upper()
        return ''

    def lu_parse(self, statement: TokenList, level: int = 0, t_idx: int = 3, **kwargs) -> None:
        """Walk *statement* and add its tokens to ``self.elements``.

        :param statement: token list (or a single token) to walk.
        :param level: sub-query depth; increased when descending into a
            parenthesis whose first keyword is SELECT.
        :param t_idx: index into ``type_name`` used for plain tokens until the
            next keyword changes it.
        :param kwargs: ``add_opt``, ``parents`` (ids new units hang from),
            ``build_relation``, ``keyword`` (last keyword seen), ``is_where``.
        """
        type_name = ['COL', 'TAB', 'SUB', 'IN', 'STRUCT']
        add_opt = kwargs.get('add_opt', True)
        parents = kwargs.get('parents', [])
        build_relation = kwargs.get('build_relation', True)
        keyword_capture = kwargs.get('keyword', '')
        is_where = kwargs.get('is_where', False)
        order_by_loop = False
        function_keywords = COL_KEYWORD + TAB_KEYWORD + IN_KEYWORD + ORDER_KEYWORD

        if not isinstance(statement, TokenList):
            statement = [statement]
        for t in statement:
            v = self._is_function_contain_keyword(t, function_keywords)
            if t.value in self.exception_list:
                # User asked for this word to be a plain name, not a keyword.
                t.is_keyword = False
                t.ttype = sqlparse.tokens.Name
                t = Identifier([t])

            # --- keyword: decide how following tokens are classified --- #
            if t.is_keyword or v is not None:
                keyword_capture = v if v is not None else t.normalized
                if keyword_capture in COL_KEYWORD:
                    t_idx = 0
                elif keyword_capture in TAB_KEYWORD or 'JOIN' in keyword_capture:
                    t_idx = 1
                elif keyword_capture in IN_KEYWORD:
                    t_idx = 3
                elif keyword_capture in ORDER_KEYWORD:
                    t_idx = 0
                    self.elements.add_order(tokens=t, key=keyword_capture, parents=parents, level=level,
                                            is_where=is_where)
                    order_by_loop = True
                    continue
                elif keyword_capture in BETWEEN_KEYWORD:
                    t_idx = 0
                    self.elements.add_between(tokens=t, key=keyword_capture, parents=parents, level=level,
                                              is_where=is_where)
                    continue
                elif keyword_capture in LIKE_KEYWORD:
                    t_idx = 0
                    self.elements.add_like(tokens=t, key=keyword_capture, parents=parents, level=level,
                                           is_where=is_where)
                    continue
                elif keyword_capture in IS_KEYWORD:
                    t_idx = 0
                    self.elements.add_is(tokens=t, key=keyword_capture, parents=parents, level=level,
                                         is_where=is_where)
                    continue
                else:
                    t_idx = 4

            # --- DATA Correction: skip whitespace and comments --- #
            if t.is_whitespace or isinstance(t, Comment) or 'Comment' in str(t.ttype):
                continue

            # -- add operation: parse the children of these groups -- #
            if isinstance(t, (Case, Operation, Values)) or order_by_loop:
                order_by_loop = False
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents,
                              level=level, keyword=keyword_capture)
                continue

            if isinstance(t, Where):
                has_named_token = any(str(tt.ttype) == 'Token.Name' or tt.value.upper() in WHERE_EXCEPT
                                      for tt in self.token_walk(t, yield_current_token=False))
                if has_named_token:
                    self.has_where = True
                self.lu_parse(t, t_idx=t_idx, build_relation=True, add_opt=True, parents=parents,
                              level=level, is_where=True)
                continue

            if isinstance(t, IdentifierList):
                # Keep level/is_where so lists inside sub-queries or WHERE are
                # attributed to the right scope.
                self.lu_parse(t, t_idx=t_idx, build_relation=False, parents=parents, add_opt=add_opt,
                              keyword=keyword_capture, level=level, is_where=is_where)
                continue

            rest = self.elements.add(t, type_name[t_idx], parents=parents, key=keyword_capture,
                                     is_where=is_where, level=level)
            if t_idx == 3:
                t_idx = 0
            if keyword_capture == 'INTO' and \
                    self.elements.by_id[len(self.elements.by_id) - 1].type == 'TAB':
                t_idx = 0

            # -- subquery: parse nested tokens returned by elements.add -- #
            if rest is None:
                continue
            for rest_v in rest:
                sub_parents = rest_v['parents']
                if build_relation:
                    self.elements.build_relation()
                for rest_t in rest_v['tokens']:
                    if isinstance(rest_t, Parenthesis):
                        level_ = level + 1 if self.get_fist_keyword(rest_t) == 'SELECT' else level
                        self.lu_parse(statement=rest_t, parents=sub_parents, t_idx=t_idx, add_opt=add_opt,
                                      build_relation=True, level=level_, is_where=is_where)

    def sql_reconstruct(self):
        """Unfinished placeholder that does nothing; ``reconstruct`` is the working variant."""
        return None

    def display_elements(self) -> None:
        """Print every parsed unit."""
        for v in self.elements:
            print(v)

    # ---------#
    def get_table_name(self, alise_on=False) -> Union[Tuple[List[str], List[str]], List[str]]:
        """Return table names, skipping sub-queries (names containing '(') and DUAL.

        :param alise_on: also return the aliases (the table name when no alias exists).
        :return: ``tab_names`` or ``(tab_names, as_names)``
        """
        tab_names = []
        as_names = []
        for tab in self.elements.by_type['TAB']:
            if '(' not in tab.name and 'DUAL' not in tab.name:
                tab_names.append(tab.name)
                as_names.append(tab.as_name if tab.as_name != 'DUMMY' else tab.name)
        if alise_on:
            return tab_names, as_names
        return tab_names

    def token_walk(self, token, topdown=True, yield_current_token=True):
        """Yield the tokens of *token* depth-first.

        :param token: an iterable of sqlparse tokens (TokenList or tuple of statements).
        :param topdown: descend into children that have ``tokens``.
        :param yield_current_token: yield *token* itself first; forwarded to the
            recursive calls, so with False only non-descended tokens are yielded.
        """
        if yield_current_token:
            yield token
        for tk in token:
            if topdown and hasattr(tk, "tokens"):
                yield from self.token_walk(tk, topdown, yield_current_token)
            else:
                yield tk

    def sql_interpreter(self, sql: str) -> str:
        """Normalise raw SQL before parsing.

        Strips trailing whitespace/semicolons, drops '(+)', block comments and
        full-line '--' comments, collapses whitespace, upper-cases every token
        and inserts ',' between the table and the '(' of ``INSERT INTO t(...)``.
        """
        # -- 1. remove trailing delimiters and comments --#
        sql = sql.strip()
        trimmed = sql.rstrip('\n\t ;')
        if trimmed:  # text made only of ';' is left untouched
            sql = trimmed
        sql = sql.replace('(+)', '')
        sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
        # NOTE(review): only comments starting a line are removed; a trailing
        # '-- ...' after code would swallow the following text once whitespace
        # is collapsed below.
        sql = re.sub(r'(?m)^ *--.*\n?', '', sql)
        sql = re.sub(r'\s+', ' ', sql).strip()
        tokens = sqlparse.parse(sql)

        # -- 2. rebuild the text token by token --#
        parts = []
        pre_keyword = None
        wirte = False  # set after INSERT INTO: next '(' gets a ',' in front
        for t in self.token_walk(tokens, True, False):
            if t.value == '(' and wirte:
                parts.append(',')
                wirte = False
            parts.append(t.value.upper())
            if t.is_keyword:
                if t.value.upper() == 'INTO' and pre_keyword == 'INSERT':
                    wirte = True
                pre_keyword = t.value.upper()
        return ''.join(parts)

    @staticmethod
    def _append_reconstructed(out: str, value: str, level: int) -> str:
        """Append one token to the reconstructed text and return the new text.

        Clause keywords start a new line indented by *level* tabs; ',', '('
        and ')' are glued to the previous text; anything else gets a trailing space.
        """
        if value in ('SELECT', 'FROM', 'WHERE', 'ORDER BY'):
            value = '\n' + '\t' * level + value
        if value in (',', '(', ')'):
            # endswith() also copes with an empty buffer (out[-1] raised IndexError).
            if out.endswith(' '):
                out = out[:-1]
            return out + value
        return out + value + ' '

    def reconstruct(self) -> str:
        """Rebuild a formatted SQL string from ``self.elements``.

        SUB units are skipped; runs of STRUCT units (other than parentheses) are
        collapsed to their highest-weighted member via Keywordstack.
        """
        out = ''
        buffer = Keywordstack()
        for i in sorted(self.elements.by_id.keys()):
            unit = self.elements.by_id[i]
            if unit.type == 'SUB':
                continue
            value = unit.token.value
            if unit.type == 'STRUCT' and value not in [')', '(']:
                buffer.insert(unit)
                continue
            if not buffer.is_empty():
                pending = buffer.pop()
                buffer.reset()
                out = self._append_reconstructed(out, pending.name, unit.level)
            out = self._append_reconstructed(out, value, unit.level)
        # Flush a keyword left pending at the end of the statement instead of dropping it.
        if not buffer.is_empty():
            pending = buffer.pop()
            out = self._append_reconstructed(out, pending.name, pending.level)
        return out


if __name__ == "__main__":
    address = '/home/yohoo/PycharmProjects/Common/'
    sql_text = ''' select `TYPE` as `type`, FILE_PATH as filePath, DISPLAY_NAME as `name`, SEQUENCE as sequence from bid_document where BID_ID = 3 and IS_VALID = 1 and rownum > 10000 order by SEQUENCE '''
    sql_text = ''' select * from tab where rownum > 1000 '''
    sp = SQLParser(sql_text, exception=['SORT'])
    sp.display_elements()
    out = sp.reconstruct()
    print(out)
    print(sp.has_where)
    print('table_names :', sp.get_table_name())
units.py:
# -*- coding: utf-8 -*-
# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
from collections import deque
from sqlparse.sql import Statement, Comment, Where, Identifier, IdentifierList, Parenthesis, Function, \
    Comparison, Operation, Token, TokenList, Values
from typing import Union, List, Tuple, Optional, Set


class ParseUnit:
    """One node of the SQL parse graph (table, column, operator, ...).

    ``parent`` holds ids of units this one hangs from; ``edges`` holds ids of
    units hanging from it (rebuilt as the inverse of ``parent`` by
    ``ParseUnitList.build_relation``).
    """

    def __init__(self):
        self.id = None
        self._name = None  # sql code name
        self._as_name = None  # as what name
        self._from_name = None  # from where
        # TAB-table, COL-column, SUB-subquery, OPT- >,<,=.., FUNC-MAX,SUM.., STRUCT, VALUE
        self._type = None
        self._keyword = None
        self._in_statement = 'OTHER'
        self._parent = set()
        self._edges = set()
        self._level = 0
        self.token = None

    @property
    def in_statement(self) -> str:
        return self._in_statement

    @property
    def level(self) -> int:
        return self._level

    @property
    def keyword(self) -> str:
        return self._keyword

    @property
    def name(self) -> str:
        return self._name

    @property
    def as_name(self) -> str:
        return self._as_name

    @property
    def from_name(self) -> str:
        return self._from_name

    @property
    def parent(self) -> set:
        return self._parent

    @property
    def type(self) -> str:
        return self._type

    @property
    def edges(self) -> set:
        return self._edges

    @keyword.setter
    def keyword(self, key: str):
        self._keyword = key.upper()

    @level.setter
    def level(self, level: int):
        self._level = level

    @name.setter
    def name(self, name: Optional[str]):
        # Strings are stored upper-cased; anything else (e.g. None) as-is.
        self._name = name.upper() if isinstance(name, str) else name

    @as_name.setter
    def as_name(self, as_name: str):
        self._as_name = as_name.upper()

    @from_name.setter
    def from_name(self, from_name: str):
        self._from_name = from_name.upper()

    @parent.setter
    def parent(self, parent: Set['ParseUnit']):
        self._parent = parent

    @type.setter
    def type(self, unit_type: str):
        if unit_type not in ['COL', 'TAB', 'SUB', 'OPT', 'FUNC', 'STRUCT', 'VALUE']:
            raise ValueError('type must be either one of following [COL, TAB, SUB, OPT, FUNC, STRUCT, VALUE]')
        self._type = unit_type.upper()

    @in_statement.setter
    def in_statement(self, state: str):
        if state == 'ORTHER':
            # tolerate the legacy misspelling
            state = 'OTHER'
        if state not in ['WHERE', 'OTHER']:
            raise ValueError('type must be either one of following [WHERE, OTHER]')
        self._in_statement = state

    @edges.setter
    def edges(self, edges: Set['ParseUnit']):
        self._edges = edges

    def overwrite(self, unit: 'ParseUnit'):
        """Copy every field that is set on *unit* (non-None, non-empty edges) onto self."""
        if unit.name is not None:
            self._name = unit.name
        if unit.as_name is not None:
            self._as_name = unit.as_name
        if unit.from_name is not None:
            self._from_name = unit.from_name
        if unit.parent is not None:
            self._parent = unit.parent
        if unit.type is not None:
            self._type = unit.type
        if unit.edges:
            self._edges = unit.edges

    def inherit(self, unit: 'ParseUnit', update_edges: bool = False):
        """Take name/alias/type from *unit* (and its from_name unless 'DUMMY')."""
        self._name = unit.name
        self._as_name = unit.as_name
        if unit.from_name != 'DUMMY':
            self._from_name = unit.from_name
        self._type = unit.type
        if update_edges:
            self._edges.add(unit.id)

    def show(self) -> str:
        """Render as ``[from.]name[ as alias]``, omitting 'DUMMY'/None parts."""
        out = ''
        if self._from_name not in ('DUMMY', None):
            out += self._from_name + '.'
        out += self._name if self._name is not None else ''
        if self._as_name not in ('DUMMY', None):
            out += ' as ' + self._as_name
        return out

    def add_parents(self, parents: Union[List[int], Set[int]]) -> None:
        for p in parents:
            self._parent.add(p)

    def __repr__(self):
        out = '%s\n' % str(self.id)
        out += '\ttype:%s\n' % self.type
        out += '\tname:%s\n' % self.name
        out += '\tkeyword:%s\n' % self.keyword
        out += '\tstatement:%s\n' % self.in_statement
        out += '\tlevel:%s\n' % self.level
        out += '\tas_name:%s\n' % self.as_name
        out += '\tfrom' + (' tab ' if self.type == 'COL' else '') + ':%s\n' % self.from_name
        out += '\tparent:%s\n' % str(self.parent)
        out += '\tedges:%s\n' % str(self.edges)
        return out


class ParseUnitList:
    """Collection of ParseUnits indexed by id (``by_id``) and by type (``by_type``)."""

    def __init__(self) -> None:
        # -- tab col relation -- #
        self.by_type = {'COL': [], 'TAB': [], 'SUB': [], 'OPT': [], 'FUNC': [], 'STRUCT': [], 'VALUE': []}
        self.by_id = dict()  # G
        self._allow_sub_has_table = False

    def __insert(self, unit: ParseUnit) -> int:
        """Assign the next id (current size of ``by_id``) to *unit*, store it, return the id."""
        uid = len(self.by_id)
        unit.id = uid
        self.by_type[unit.type].append(unit)
        self.by_id[unit.id] = unit
        return uid

    def __update_by_type(self) -> None:
        """Write every unit listed in ``by_type`` back into ``by_id``."""
        for key in ['SUB', 'TAB', 'OPT', 'FUNC', 'COL', 'STRUCT', 'VALUE']:
            for unit in self.by_type[key]:
                self.by_id[unit.id] = unit

    def __update_by_id(self):
        """Rebuild ``by_type`` from ``by_id``."""
        self.by_type = {'COL': [], 'TAB': [], 'SUB': [], 'OPT': [], 'FUNC': [], 'STRUCT': [], 'VALUE': []}
        for unit in self.by_id.values():
            self.by_type[unit.type].append(unit)

    @staticmethod
    def _new_unit(name, unit_type: str, key: str, level: int, token, is_where: bool,
                  parents: Optional[List[int]]) -> ParseUnit:
        """Build (but do not insert) a unit with the common fields filled in."""
        unit = ParseUnit()
        unit.name = name
        unit.type = unit_type
        unit.keyword = key
        unit.level = level
        unit.token = token
        if is_where:
            unit.in_statement = 'WHERE'
        if parents:
            unit.add_parents(parents)
        return unit

    ########################################
    #             add function             #
    ########################################
    # ----------- add by token type -----------#
    def _add_Identifier(self, tokens: Token, type: str, key: str, level: int, is_where: bool,
                        parents: List[int] = None) -> Tuple[int, Optional[Union[Token, TokenList]]]:
        """Insert a unit for a name-like token; return ``(id, abnormal)``.

        For ``a.b`` the part before the dot becomes ``from_name`` and the part
        after it ``name``; a bare ``b`` gets ``name`` 'B' and ``from_name``
        'DUMMY'. A grouped child (``ttype is None``) supplies ``as_name``; if it
        is not an Identifier it is returned as *abnormal* for the caller to parse.
        """
        out = ParseUnit()
        # Parenthesised text (other than a lone '(') marks a sub-structure.
        out.type = 'SUB' if '(' in tokens.value and tokens.value != '(' else type
        out.keyword = key
        out.level = level
        out.token = tokens
        if is_where:
            out.in_statement = 'WHERE'
        if parents:
            out.add_parents(parents)

        dot_flag = 1  # odd: next Name is a qualifier candidate; even: it follows a '.'
        abnormal = None
        try:
            for t in tokens:
                ttype = str(t.ttype).upper()
                if ttype == 'TOKEN.PUNCTUATION' and t.value == '.':
                    dot_flag += 1
                    continue
                if ttype == 'TOKEN.NAME':
                    if dot_flag % 2 == 0:
                        out.name = t.value
                        dot_flag += 1
                    else:
                        out.from_name = t.value
                if t.ttype is None:
                    out.as_name = t.value
                    if not isinstance(t, Identifier):
                        abnormal = t
            if dot_flag <= 1:
                # No dot seen: the Name collected as from_name is really the name.
                out.name = out.from_name
                out.from_name = 'DUMMY'
        except Exception:
            # e.g. an ungrouped Token that cannot be iterated: keep its whole text.
            out.name = tokens.value

        if out.as_name is None:
            out.as_name = 'DUMMY'
        if out.name is None:
            out.name = abnormal.value if abnormal is not None else out.as_name

        # -- add order by or group by: hang from the closest matching operator -- #
        if key in ('ORDER BY', 'GROUP BY'):
            acquire_id = None
            for uid in range(len(self.by_id) - 1, -1, -1):
                acquire_id = uid  # falls back to 0 when no operator matches
                unit = self.by_id[uid]
                if unit.type == 'OPT' and unit.name == key:
                    break
            if acquire_id is not None:
                out.parent.add(acquire_id)

        # -- add to like: presumably the preceding unit is the LIKE operator -- #
        if key == 'LIKE':
            out.add_parents([len(self.by_id) - 1])
        uid = self.__insert(out)
        return uid, abnormal

    def _add_Comparison(self, tokens: Comparison, type: str, key: str, level: int, is_where: bool,
                        parents: List[int] = None) -> Optional[List[dict]]:
        """Insert left operand, comparison operator, right operand (in that order).

        Both operands get ``len(by_id) + 1`` (taken before the left side is
        added) as parent, which is the operator's id when the left side adds
        exactly one unit. Returns the combined pending work from both sides.
        """
        opt = None
        for t in tokens:
            if str(t.ttype).upper() == 'TOKEN.OPERATOR.COMPARISON':
                opt = t.value
        unit = self._new_unit(opt, 'OPT', key, level, None, is_where, parents)
        non_blank = [t for t in tokens if not t.is_whitespace]
        if len(non_blank) >= 2:
            unit.token = non_blank[1]
        expect_id = len(self.by_id) + 1

        operand_parents = [expect_id]
        left_rest = self.add(tokens=tokens.left, type=type, key=key, level=level, parents=operand_parents,
                             is_where=is_where)
        self.__insert(unit)
        right_rest = self.add(tokens=tokens.right, type=type, key=key, level=level, parents=operand_parents,
                              is_where=is_where)
        if left_rest and right_rest:
            return left_rest + right_rest
        if left_rest:
            return left_rest
        return right_rest

    def _add_Operation(self, tokens: Operation, type: str, key: str, level: int, is_where: bool,
                       parents: List[int] = None):
        """Insert an OPT unit for the whole operation, then each child token under it."""
        unit = self._new_unit(tokens.value, 'OPT', key, level, tokens, is_where, parents)
        expect_id = len(self.by_id)
        self.__insert(unit)
        for t in tokens.tokens:
            self.add(tokens=t, type=type, key=key, level=level, parents=[expect_id], is_where=is_where)

    def _add_Function(self, tokens: Function, key: str, level: int, is_where: bool,
                      parents: List[int] = None) -> Tuple[int, Optional[list]]:
        """Insert a FUNC unit named after the first token; return its id and the remaining tokens."""
        unit = self._new_unit(tokens.tokens[0].value, 'FUNC', key, level, tokens.tokens[0], is_where, parents)
        uid = self.__insert(unit)
        return uid, tokens.tokens[1:]

    def _add_Parenthesis(self, tokens: Parenthesis, key: str, level: int, is_where: bool,
                         parents: List[int] = None) -> Tuple[int, Parenthesis]:
        """Insert a SUB unit for a parenthesis; return its id and the parenthesis to parse."""
        unit = self._new_unit(tokens.value, 'SUB', key, level, tokens, is_where, parents)
        unit.from_name = 'DUMMY'
        unit.as_name = 'DUMMY'
        uid = self.__insert(unit)
        return uid, tokens

    def add(self, tokens: Union[Token, TokenList], type: str, is_where: bool, key: str, level: int,
            parents: List[int] = None) -> Optional[List[dict]]:
        """Dispatch *tokens* to the matching ``_add_*`` builder.

        :return: None, or a list of ``{'parents': [...], 'tokens': [...]}``
            dicts describing nested tokens still to be parsed by the caller.
        """
        if isinstance(tokens, Identifier):
            uid, abnormal = self._add_Identifier(tokens=tokens, type=type, parents=parents, key=key, level=level,
                                                 is_where=is_where)
            if abnormal is None:
                return None
            if isinstance(abnormal, Function):
                return self.add(tokens=abnormal, type=type, parents=[uid], key=key, level=level, is_where=is_where)
            return [{'parents': [uid], 'tokens': [abnormal]}]
        if isinstance(tokens, Comparison):
            return self._add_Comparison(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where,
                                        level=level)
        if isinstance(tokens, Function):
            uid, token_list = self._add_Function(tokens=tokens, parents=parents, key=key, level=level,
                                                 is_where=is_where)
            return [{'parents': [uid], 'tokens': token_list}]
        if isinstance(tokens, Parenthesis):
            uid, token = self._add_Parenthesis(tokens=tokens, parents=parents, key=key, level=level,
                                               is_where=is_where)
            return [{'parents': [uid], 'tokens': [token]}]
        if isinstance(tokens, Values):
            return self._add_value(tokens=tokens, level=level, parents=parents, is_where=is_where)
        if isinstance(tokens, Operation):
            self._add_Operation(tokens=tokens, type=type, parents=parents, key=key, is_where=is_where, level=level)
            return None
        if tokens.value.upper() == 'IN':
            self._add_In(tokens=tokens, is_where=is_where, key=key, level=level, parents=parents)
            return None
        # Literal tokens become VALUE; everything else (including groups with no
        # ttype, which used to crash on ttype[0]) becomes STRUCT.
        ttype = tokens.ttype
        is_value = bool(ttype) and str(ttype[0]) in ['Literal', 'Number']
        unit_type = 'VALUE' if is_value else 'STRUCT'
        uid, token_list = self._add_Identifier(tokens=tokens, type=unit_type, parents=parents, key=key,
                                               level=level, is_where=is_where)
        if token_list is not None:
            self.add(tokens=token_list, type=unit_type, is_where=is_where, key=key, level=level, parents=[uid])
        return None

    # ----------- add by keywords ----------- #
    def _add_In(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> None:
        """Insert an IN operator linked to the previous unit and the next (future) one."""
        cur_id = len(self.by_id)
        left_id = cur_id - 1
        right_id = cur_id + 1
        unit = self._new_unit('IN', 'OPT', key, level, tokens, is_where, parents)
        unit.edges = {left_id, right_id}
        self.__insert(unit)
        self.by_id[left_id].parent.add(cur_id)

    def _add_keyword_opt(self, tokens: Token, name: str, key: str, level: int, is_where: bool,
                         parents: Optional[List[int]], edge: int) -> None:
        """Insert an OPT unit with a single edge to the unit id *edge*."""
        unit = self._new_unit(name, 'OPT', key, level, tokens, is_where, parents)
        unit.edges.add(edge)
        self.__insert(unit)

    def add_order(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> Optional[
        List[dict]]:
        """Insert an ORDER BY / GROUP BY operator pointing at the unit after it."""
        self._add_keyword_opt(tokens, tokens.value, key, level, is_where, parents, len(self.by_id) + 1)
        return None

    def add_like(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> Optional[
        List[dict]]:
        """Insert a LIKE operator pointing at the unit before it."""
        self._add_keyword_opt(tokens, tokens.value, key, level, is_where, parents, len(self.by_id) - 1)
        return None

    def add_between(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> Optional[
        List[dict]]:
        """Insert a BETWEEN operator pointing at the unit before it (bounds linked in build_relation)."""
        self._add_keyword_opt(tokens, 'BETWEEN', key, level, is_where, parents, len(self.by_id) - 1)
        return None

    def _add_value(self, tokens: Token, level: int, is_where: bool, parents: List[int] = None) -> Optional[List[dict]]:
        """Insert a VALUES operator and a SUB unit for each parenthesised row.

        :return: pending work for each row so the caller can parse its contents.
        """
        self._allow_sub_has_table = True
        col_id = len(self.by_id) - 1
        unit = self._new_unit('VALUES', 'OPT', 'VALUES', level, tokens, is_where, parents)
        unit.edges = {col_id, col_id + 2}
        self.__insert(unit)
        out = []
        for t in tokens.tokens[1:]:
            if isinstance(t, Parenthesis):
                # _add_Parenthesis already inserts the unit; it returns (id, token).
                sub_id, sub_token = self._add_Parenthesis(tokens=t, key='VALUES', level=level,
                                                          parents=[col_id + 1], is_where=is_where)
                out.append({'parents': [sub_id], 'tokens': [sub_token]})
        return out

    def add_is(self, tokens: Token, key: str, level: int, is_where: bool, parents: List[int] = None) -> Optional[
        List[dict]]:
        """Insert an IS operator pointing at the unit before it."""
        self._add_keyword_opt(tokens, tokens.value, key, level, is_where, parents, len(self.by_id) - 1)
        return None

    #########################################
    def __iter__(self):
        return iter(self.by_id.values())

    #########################################
    #        build relation function        #
    #########################################
    def build_relation(self):
        """Link units to each other.

        1. Register every TAB under both its alias and its name.
        2. COL units (and SUB units after ``_add_value``) with a qualifier get
           the TABs of that name as parents; unqualified ones get every TAB of
           the same level, and record its alias/name when exactly one exists.
        3. The next three units at the level of a BETWEEN operator get it as parent.
        4. ``edges`` become parent links on their targets, then ``edges`` is
           rebuilt as the inverse of ``parent``. Ids not present are skipped.
        """
        symbol_idx = dict()  # {as_name/name: [id]}
        check_keys = ['COL'] if not self._allow_sub_has_table else ['COL', 'SUB']
        # -- build tab col relation --#
        for key in ['SUB', 'TAB', 'COL']:
            for unit in self.by_type[key]:
                key_i = unit.type
                if key_i == 'TAB':
                    for symbol in (unit.as_name, unit.name):
                        symbol_idx.setdefault(symbol, []).append(unit.id)
                if key_i not in check_keys:
                    continue
                if unit.from_name != 'DUMMY':
                    for parent in symbol_idx.get(unit.from_name, []):
                        unit.parent.add(parent)
                else:
                    all_parents = self.add_all_parents(unit.level)
                    if len(all_parents) == 1:
                        parent = self.by_id[next(iter(all_parents))]
                        unit.from_name = parent.as_name if parent.as_name != 'DUMMY' else parent.name
                    unit.add_parents(all_parents)
        self.__update_by_type()

        # --- between handler and edge -> parent links ---#
        between_count = None
        blevel = None
        b_id = None
        for uid, unit in self.by_id.items():
            if between_count is not None:
                if blevel == unit.level:
                    between_count += 1
                    unit.parent.add(b_id)
                if between_count >= 3:  # lower bound, AND, upper bound
                    between_count = None
                    blevel = None
                    b_id = None
            if unit.name == 'BETWEEN':
                between_count = 0
                blevel = unit.level
                b_id = unit.id
            for ed in unit.edges:
                if ed in self.by_id:
                    self.by_id[ed].parent.add(uid)

        # --- build edges as the inverse of parents --- #
        for uid, unit in self.by_id.items():
            for pa in unit.parent:
                if pa in self.by_id:
                    self.by_id[pa].edges.add(uid)
        self._allow_sub_has_table = False
        self.__update_by_id()

    def add_all_parents(self, level: int) -> Set[int]:
        """Return the ids of all TAB units at *level*."""
        return {unit.id for unit in self.by_type['TAB'] if unit.level == level}

    def build_relation_by_tab_info(self):
        pass

    ############################################
    #              graph search                #
    ############################################
    def find_root(self, graph: ParseUnit, col_only: bool = False) -> Optional[List[int]]:
        """Breadth-first walk along ``edges`` from *graph*.

        :param col_only: collect COL units whose name has no '(' instead of
            units without edges.
        :return: collected ids in visiting order.
        """
        root = []
        visited = set()
        queue = deque([graph.id])
        while queue:
            v = queue.popleft()
            if v in visited or type(v) is not int:
                continue
            visited.add(v)
            units = self.by_id[v]
            if col_only:
                if units.type == 'COL' and '(' not in units.name:
                    root.append(units.id)
            elif len(units.edges) == 0:
                root.append(units.id)
            queue.extend(units.edges)
        return root

    def find_tab(self, colum: ParseUnit, tab_only: bool = False) -> Optional[List[int]]:
        """Breadth-first walk along ``parent`` from *colum*.

        :param tab_only: collect TAB units instead of units without parents.
        :return: collected ids in visiting order.
        """
        tabs = []
        visited = set()
        queue = deque([colum.id])
        while queue:
            v = queue.popleft()
            if v in visited:
                continue
            visited.add(v)
            units = self.by_id[v]
            if tab_only:
                if units.type == 'TAB' and units.id not in tabs:
                    tabs.append(units.id)
            elif len(units.parent) == 0 and units.id not in tabs:
                tabs.append(units.id)
            queue.extend(units.parent)
        return tabs

    ############################
    #       remove node        #
    ############################
    def remove(self, id_list: List[int]):
        """Delete the units in *id_list* together with their removal trunks."""
        all_list = []
        for uid in id_list:
            all_list.extend(self._get_remove_trunk(self.by_id[uid]))
        # A unit may belong to several trunks; delete each id only once.
        for uid in dict.fromkeys(all_list):
            del self.by_id[uid]
        # Keep by_type consistent so removed units are not resurrected later
        # by __update_by_type (called from build_relation).
        self.__update_by_id()

    def _get_remove_trunk(self, unit: ParseUnit) -> List[int]:
        """Collect ids connected to *unit* through non-TAB/SUB units.

        TAB/SUB units stop the walk; they are included only when deeper than *unit*.
        """
        id_list = []
        target_level = unit.level
        visited = set()
        queue = deque([unit.id])
        while queue:
            v = queue.popleft()
            if v in visited or type(v) is not int:
                continue
            visited.add(v)
            units_ = self.by_id[v]
            if units_.type != 'TAB' and units_.type != 'SUB':
                id_list.append(units_.id)
                queue.extend(units_.parent)
                queue.extend(units_.edges)
            elif units_.level > target_level:
                id_list.append(units_.id)
        return id_list


class SQLGrammarError(Exception):
    pass