# Python 练习——模拟 SQL

# 要求：实现模拟 SQL 的登录及增删改查功能

# -*- coding: utf-8 -*-
import ast
import functools
import operator
import os
import sys

#登录验证部分

# Module-level login state, read and written by login(), log_out() and the
# `wrapper` decorator.
status_dic = dict(
    username=None,  # name of the currently logged-in user, or None
    status=False,   # True once a user has logged in successfully
)

def wrapper(func):  # login-check decorator
    """Decorate *func* so that it only runs for a logged-in user.

    If ``status_dic['status']`` is truthy, *func* is called directly.
    Otherwise a prompt to log in is printed and ``login()`` is invoked.
    *func* then runs only if ``login()`` returns a truthy value. ``login()``
    may itself start ``sql_main`` on success and return that call's result.

    Arguments are forwarded to *func* unchanged. ``functools.wraps`` keeps
    the wrapped function's ``__name__`` and ``__doc__``.
    """
    @functools.wraps(func)
    def inner1(*args, **kwargs):
        if status_dic['status']:
            return func(*args, **kwargs)
        print("请先进行登录")
        if login():
            return func(*args, **kwargs)
        return None
    return inner1






def login(*args):   # login module
    """Log a user in, either directly or interactively.

    With positional arguments, the first one is recorded as the logged-in
    user without prompting, and ``True`` is returned. ``register()`` calls
    ``login(username, password)`` right after a successful sign-up.

    Without arguments, the user gets three wrong-password attempts in total:

    * correct password: mark the user logged in and return ``sql_main()``;
    * unknown user name: offer registration (``register()``) or go back to
      the menu (``home_page()``). An unrecognised answer re-prompts without
      using up an attempt;
    * three wrong passwords: print a message and return ``None``.

    If a user is already logged in, a message is printed and ``True`` is
    returned.
    """
    if args:
        status_dic['username'] = args[0]
        status_dic['status'] = True
        return True

    if status_dic['status']:
        # The attempt loop below only advanced its counter on a wrong password.
        # With a user already logged in it therefore spun forever.
        print("%s用户已登录" % status_dic['username'])
        return True

    # Each non-blank line of login.txt is a dict literal written by register().
    # literal_eval parses it without executing code, unlike eval().
    try:
        with open("login.txt", encoding="utf-8", mode="r") as f:
            users = [ast.literal_eval(line) for line in f if line.strip()]
    except FileNotFoundError:
        users = []  # nobody has registered yet

    count = 0
    while count < 3:
        username = input("请输入登录用户名:").strip()
        # First record with this name, like the original per-line scan.
        user = next((u for u in users if u.get("name") == username), None)
        if user is None:
            print("检查用户名未注册...")
            y = input("是否注册(yes/no):").strip()
            if y in ("yes", "y"):
                return register()
            if y in ("no", "n"):
                return home_page()
            print("输入错误...")
            continue
        password = input("请登录用户密码:").strip()
        if password == user.get("password"):
            print("login successfully")
            status_dic['username'] = username
            status_dic['status'] = True
            return sql_main()
        print("您还有%s次机会" % (2 - count))  # attempts left out of three
        count += 1
    print("错误太多次啦...")
    return None


def register():     # registration module
    """Interactively create a new account in ``login.txt`` and log it in.

    Re-prompts while the chosen user name already exists. The new account
    is appended as a dict literal ``{'name': ..., 'password': ...}`` on its
    own line. Then ``login(username, password)`` marks it logged in, and its
    return value is returned.
    """
    # Collect the existing user names once. literal_eval parses the dict
    # literals without executing code. A missing file means no users yet.
    try:
        with open("login.txt", encoding="utf-8", mode="r") as f:
            existing = {ast.literal_eval(line).get("name")
                        for line in f if line.strip()}
    except FileNotFoundError:
        existing = set()
    print("账号注册".center(50, "#"))
    while True:
        username = input("请输入您注册的用户名:").strip()
        if username in existing:
            print("用户名已经存在,请重新输入")
            continue
        password = input("请输入您注册的密码:").strip()
        # NOTE(review): passwords are stored in plain text.
        with open("login.txt", encoding="utf-8", mode="a") as f:
            f.write("{}\n".format({"name": username, "password": password}))
        print("注册成功")
        return login(username, password)



def log_out():
    """Clear the login state and report which user was logged out.

    When nobody is logged in, only a notice is printed and the state is
    left untouched.
    """
    if not status_dic['status']:
        # Previously this printed "None用户已注销" when nobody was logged in.
        print("当前没有登录用户")
        return
    print("%s用户已注销" % status_dic['username'])
    status_dic['username'] = None
    status_dic['status'] = False


def Quit_exit():
    """Print a farewell message and terminate the program (raises SystemExit)."""
    print("退出程序...")
    # sys.exit rather than the site-provided exit(). exit() is meant for the
    # interactive shell and is missing when Python runs with -S.
    sys.exit()




def home_page():
    """Show the main menu in a loop and dispatch the chosen action.

    Valid choices are the keys of the dispatch table (1-5). Non-numeric or
    out-of-range input is reported and the menu is shown again. The loop
    never returns on its own; it ends only when an action raises (for
    example ``Quit_exit`` raising ``SystemExit``).
    """
    dic = {
        1: login,
        2: register,
        3: sql_main,
        4: log_out,
        5: Quit_exit,
    }

    while True:
        print('1:登录用户\n2:注册用户\n3:执行sql\n4:注销\n5:退出程序')

        bianhao = input("请输入选择的序列号:").strip()
        # Use isdecimal rather than isdigit. Characters such as '²' pass isdigit
        # but make int() raise ValueError.
        if not bianhao.isdecimal():
            print("输入的序列号只能是数字")
            continue
        action = dic.get(int(bianhao))
        # The old range check (0 < n <= 8) let 6-8 through to a KeyError.
        if action is None:
            print("输入编号不在范围内")
            continue
        action()


#sql主程序,只有登录成功后才可以执行sql语句
# Help text listing the supported SQL syntax; sql_main prints it in red
# (ANSI escape sequence) when the SQL prompt starts.
s = '''
支持的sql语法
1:insert
  insert into user values 张三,20,13412345678,BOSS
2:delete
  delete from user where id =6
3:update
  update user set name = Tom where id > 5
4:select
  select * from user
  select name,age from user where age > 2
  select * from user where job='IT'
  select * from user where phone like 133
'''
@wrapper
def sql_main():
    """Run the interactive SQL prompt; only reachable for a logged-in user.

    Reads statements from ``sql>`` until ``exit``. Each statement is parsed
    with ``sql_parse`` and executed with ``sql_action``. ``select`` results
    (a ``[fields, rows]`` list) are printed row by row, followed by a count.
    Other non-``None`` results are printed as-is.
    """
    print('\033[0;31m%s\033[0m' % s)
    while True:
        sql = input("sql>").strip()
        if sql == 'exit':
            break
        if not sql:
            continue

        # Malformed statements (missing table, unknown column, unreadable file,
        # bad where expression, ...) used to crash the whole program. Report
        # them and keep the prompt alive instead.
        try:
            sql_dic = sql_parse(sql)
            if not sql_dic:
                print('sql语法不正确')
                continue
            res = sql_action(sql_dic)
        except Exception as e:  # REPL boundary: report and keep going
            print('sql执行出错: %s' % e)
            continue

        if isinstance(res, list):
            for row in res[-1]:
                print(row.strip() if isinstance(row, str) else row)
            print("\n共计查询出%s条数据" % len(res[1]))
        elif res is not None:
            # delete/update return None; there is nothing extra to show.
            print(res)


#sql处理部分
#第一部分:sql解析
def sql_parse(sql):  # insert delete update select
    """Split *sql* on whitespace and dispatch to the matching clause parser.

    :param sql: raw statement text, e.g. ``"select * from user"``
    :return: the parser's ``sql_dic``, or an empty dict when *sql* is blank
             or its first word is not insert/delete/update/select
    """
    sql_l = sql.split()
    if not sql_l:
        return {}  # blank input: sql_l[0] would raise IndexError

    parse_func = {
        'insert': insert_parse,
        'delete': delete_parse,
        'update': update_parse,
        'select': select_parse,
    }
    parser = parse_func.get(sql_l[0])
    # An empty dict (instead of the old '') keeps the return type a dict.
    # Callers only test its length/truthiness.
    return parser(sql_l) if parser else {}



def insert_parse(sql_l):
    """Parse a tokenised ``insert into <table> values <v1,v2,...>`` statement.

    :param sql_l: the statement split on whitespace
    :return: dict whose ``func`` is ``insert``, plus one token list per
             keyword insert/into/values, filled in by handle_parse
    """
    sql_dic = {'func': insert}
    sql_dic.update((clause, []) for clause in ('insert', 'into', 'values'))
    return handle_parse(sql_l, sql_dic)

def delete_parse(sql_l):
    """Parse a tokenised ``delete from <table> where <condition>`` statement.

    :param sql_l: the statement split on whitespace
    :return: dict whose ``func`` is ``delete``, plus one token list per
             keyword delete/from/where, filled in by handle_parse
    """
    sql_dic = {'func': delete}
    sql_dic.update((clause, []) for clause in ('delete', 'from', 'where'))
    return handle_parse(sql_l, sql_dic)

def update_parse(sql_l):
    """Parse a tokenised ``update <table> set <col = val> where <cond>`` statement.

    :param sql_l: the statement split on whitespace
    :return: dict whose ``func`` is ``update``, plus one token list per
             keyword update/set/where, filled in by handle_parse
    """
    sql_dic = {'func': update}
    sql_dic.update((clause, []) for clause in ('update', 'set', 'where'))
    return handle_parse(sql_l, sql_dic)

def select_parse(sql_l):
    """Parse a tokenised ``select <fields> from <table> [where ...] [limit N]``.

    :param sql_l: the statement split on whitespace
    :return: dict whose ``func`` is ``select``, plus token lists for the
             selected fields, the table, the filter and the limit
    """
    sql_dic = {'func': select}
    # select: fields, from: table, where: filter, limit: row cap
    sql_dic.update((clause, []) for clause in ('select', 'from', 'where', 'limit'))
    return handle_parse(sql_l, sql_dic)



def handle_parse(sql_l, sql_dic):
    """Distribute the tokens of *sql_l* into the clause lists of *sql_dic*.

    A token that names a clause key of *sql_dic* (e.g. ``from`` or ``where``)
    starts that clause. The following tokens are appended to it until the
    next clause keyword. Tokens before the first keyword are ignored. A
    non-empty ``where`` clause is then converted by ``where_parse``.

    :param sql_l: statement split on whitespace,
                  e.g. ``['select', '*', 'from', 'user']``
    :param sql_dic: skeleton built by one of the *_parse functions;
                    mutated in place
    :return: *sql_dic*
    """
    key = None
    for token in sql_l:
        # 'func' holds the executor callable, not a clause list. Treating a
        # literal 'func' token as a keyword used to raise AttributeError.
        if token != 'func' and token in sql_dic:
            key = token
        elif key is not None:
            sql_dic[key].append(token)
    if sql_dic.get('where'):
        # e.g. ['id>4', 'and', 'id<10'] -> [['id', '>', '4'], 'and', ['id', '<', '10']]
        sql_dic['where'] = where_parse(sql_dic.get('where'))
    return sql_dic


def where_parse(where_l):
    """Group where-clause tokens into condition triples and logic words.

    Consecutive non-logic tokens are concatenated, so ``['id', '>', '4']`` and
    ``['id>4']`` are equivalent, and then split by ``three_parse``. The words
    ``and``/``or``/``not`` are kept as separate items. For example,
    ``['id>', '4', 'and', 'id', '<10']`` becomes
    ``[['id', '>', '4'], 'and', ['id', '<', '10']]``.

    Empty tokens are skipped. An empty group (an empty list, or the end of
    ``['id>4', 'and']``) adds nothing, rather than a bogus ``['', 'like']``
    condition.
    """
    logic_words = ('and', 'or', 'not')
    res = []
    char = ''
    for token in where_l:
        if not token:
            continue
        if token in logic_words:
            if char:
                res.append(three_parse(char))
            res.append(token)
            char = ''
        else:
            char += token
    if char:
        res.append(three_parse(char))
    return res

def three_parse(str_l):
    """Split one condition string into ``[field, operator, value]``.

    A run of the characters ``>``, ``<`` and ``=`` forms the operator, e.g.
    ``'id<=10'`` -> ``['id', '<=', '10']``. When no operator character is
    present, the string is split on its first ``like`` instead:
    ``'phonelike133'`` -> ``['phone', 'like', '133']``.

    :param str_l: condition with whitespace already removed (see where_parse)
    :return: list of strings; three items for well-formed input
    """
    key = ['>', '<', '=']
    res = []
    char = ''
    opt = ''
    tag = False   # True while inside a run of operator characters
    for i in str_l:
        if i in key:
            tag = True
            if len(char) != 0:
                res.append(char)   # field name preceding the operator
                char = ''
            opt += i

        if not tag:
            char += i

        if tag and i not in key:
            # First character after the operator: emit the operator and
            # start collecting the value.
            tag = False
            res.append(opt)
            opt = ''
            char += i
    res.append(char)

    # like support: split only on the FIRST 'like', so that a value which
    # itself contains 'like' (e.g. 'namelikealike') stays intact.
    if len(res) == 1:
        res = res[0].split('like', 1)
        res.insert(1, 'like')
    return res

# 第二部分：sql 执行
def sql_action(sql_dic):
    """Execute the parsed statement *sql_dic*.

    :param sql_dic: parsed statement whose ``'func'`` entry is the executor
                    (insert/delete/update/select)
    :return: whatever that executor returns
    """
    executor = sql_dic.get('func')
    return executor(sql_dic)


def insert(sql_dic):
    """Append one record to the table file named by the ``into`` clause.

    Example ``sql_dic``: ``{'insert': [], 'into': ['user'],
    'values': ['a,23,11103213123,BOSS']}``. The values must hold exactly four
    comma-separated fields. The new row is written as ``<id>,<values>``,
    where ``<id>`` is one more than the largest numeric id in the first column.

    :return: the values string (also when it is rejected for missing fields)
    """
    table = sql_dic.get('into')[0]
    # Re-join the tokens so "values 张三, 20,..." (spaces after commas) is read
    # as one value list instead of only its first token.
    value = ''.join(sql_dic.get('values'))
    if len(value.strip().split(',')) != 4:
        print('缺少字段')
        return value

    with open(table, 'r', encoding='utf-8') as f:
        content = f.read()
    rows = [row for row in content.splitlines() if row.strip()]
    # Next id = largest existing id + 1. Counting rows (the old approach)
    # reused an existing id as soon as any row had been deleted.
    ids = [int(row.split(',')[0]) for row in rows[1:]
           if row.split(',')[0].strip().isdecimal()]
    new_id = max(ids, default=0) + 1

    # Separate the record from the previous one with exactly one newline. The
    # old code always prefixed '\n' and left a blank line after files that
    # end in a newline, which later broke delete/update/select.
    prefix = '\n' if content and not content.endswith('\n') else ''
    with open(table, 'a', encoding='utf-8') as f:
        f.write('%s%s,%s' % (prefix, new_id, value))
    print('insert successful')
    return value



def delete(sql_dic):
    """Delete the rows of the ``from`` table that match the ``where`` clause.

    Example ``sql_dic``: ``{'delete': [], 'from': ['user'],
    'where': [['id', '>', '2']]}``. With an empty ``where``, every data row
    is deleted (as in SQL). The header line is always kept. The table is
    rewritten through ``<table>.bak``, which then replaces the original file.
    """
    table = sql_dic.get('from')[0]
    table_bak = '%s.bak' % table
    where = sql_dic.get('where')
    with open(table, 'r', encoding='utf-8') as read_f, \
            open(table_bak, 'w', encoding='utf-8') as write_f:
        header = read_f.readline().strip()
        # Write the header unconditionally. The old code only wrote it while
        # copying a data row, so a header-only table lost its header.
        if header:
            write_f.write('%s\n' % header)
        fields = header.split(',')
        for line in read_f:
            record = line.strip()
            if not record:
                continue  # a blank line would make logic_action raise KeyError
            # Strip before splitting so the last column carries no '\n', which
            # previously broke the comparison in logic_action.
            dic = dict(zip(fields, record.split(',')))
            matched = logic_action(dic, where) if where else True
            if not matched:
                write_f.write(line)
    # os.replace overwrites the destination in one step on POSIX and Windows.
    os.replace(table_bak, table)
    print('delete successful')



def update(sql_dic):
    """Set one column on the rows of the ``update`` table matching ``where``.

    Example ``sql_dic``: ``{'update': ['user'], 'set': ['name', '=', 'Tom'],
    'where': [['id', '>', '5']]}``. With an empty ``where``, every data row
    is updated. A value wrapped in matching quotes (``'Tom'``) is stored
    without them. The table is rewritten through ``<table>.bak``, which then
    replaces the original file.

    :raises ValueError: the ``set`` column is not in the table header
    """
    table = sql_dic.get('update')[0]
    table_bak = '%s.bak' % table
    where = sql_dic.get('where')
    # 'set' tokens such as ['name', '=', 'Tom'] -> column 'name', value 'Tom'.
    # partition keeps any '=' inside the value intact.
    column, _, new_value = ''.join(sql_dic['set']).partition('=')
    if len(new_value) >= 2 and new_value[0] == new_value[-1] and new_value[0] in '\'"':
        new_value = new_value[1:-1]

    with open(table, 'r', encoding='utf-8') as read_f:
        header = read_f.readline().strip()
        fields = header.split(',')
        # Resolve the column before creating the .bak file, so that an unknown
        # column raises ValueError without leaving a partial backup behind.
        col_index = fields.index(column)
        with open(table_bak, 'w', encoding='utf-8') as write_f:
            write_f.write('%s\n' % header)
            for line in read_f:
                record = line.strip()
                if not record:
                    continue  # a blank line would make logic_action raise KeyError
                values = record.split(',')
                dic = dict(zip(fields, values))
                if where and not logic_action(dic, where):
                    write_f.write(line)
                    continue
                # Replace only the target field. str.replace on the whole line
                # also rewrote equal substrings in other columns, and dropped
                # the newline when the last column was updated.
                values[col_index] = new_value
                write_f.write(','.join(values) + ('\n' if line.endswith('\n') else ''))
    os.replace(table_bak, table)
    print('update successful')

def select(sql_dic):
    """Run a select: read the ``from`` table, then filter, limit and project.

    :param sql_dic: parsed select statement (keys select/from/where/limit)
    :return: ``[field_names, rows]`` as built by ``search_action``
    """
    table = sql_dic.get('from')[0]
    # `with` closes the file even when filtering raises. The old explicit
    # close() was skipped on errors.
    with open(table, 'r', encoding='utf-8') as f:
        lines = f.readline()  # header line holding the column names
        filter_res = where_action(f, sql_dic.get('where'), lines)
    limit_res = limit_action(filter_res, sql_dic.get('limit'))
    return search_action(limit_res, sql_dic.get('select'), lines)



def where_action(f, where_l, lines):
    """Return the data rows of *f* that satisfy *where_l*.

    :param f: open table file, positioned after the header line
    :param where_l: parsed where clause (see where_parse); empty means no filter
    :param lines: the header line holding comma-separated column names
    :return: without a filter, the raw row strings; with a filter, each
             matching row split on ',' (the last field keeps its newline)
    """
    if len(where_l) == 0:
        # Skip blank lines so they are neither printed nor counted.
        return [row for row in f if row.strip()]

    fields = lines.strip().split(',')
    res = []
    for row in f:
        if not row.strip():
            continue  # a blank row would make logic_action raise KeyError
        # Compare this record against the whole where clause.
        dic = dict(zip(fields, row.strip().split(',')))
        if logic_action(dic, where_l):
            res.append(row.split(','))
    return res




def logic_action(dic, where_l):
    """Evaluate a parsed where clause against one record.

    :param dic: record mapping column name -> string value
    :param where_l: output of where_parse, e.g.
        ``[['name', 'like', 'alex'], 'or', ['id', '<', '4']]``
    :return: the boolean obtained by combining the conditions with the
        and/or/not words in order (Python operator precedence applies)
    :raises KeyError: a condition names an unknown column
    :raises ValueError: a value is not a Python literal (e.g. an unquoted
        word or an expression), or the operator is unsupported
    """
    compare = {
        '==': operator.eq,
        '>': operator.gt,
        '<': operator.lt,
        '>=': operator.ge,
        '<=': operator.le,
    }
    res = []
    for i in where_l:
        if isinstance(i, list):
            i_k, opt, i_v = i
            raw = dic[i_k]
            if opt == 'like':
                # Substring match on the stored value; quotes around the
                # pattern are optional.
                i = str(i_v.strip('\'"') in raw)
            else:
                if opt == '=':
                    opt = '=='
                if opt not in compare:
                    raise ValueError('unsupported operator: %s' % opt)
                # Digit-only fields compare as numbers, others as strings.
                # The user-supplied value is parsed with literal_eval instead
                # of eval, so it can be a number or a quoted string but never
                # code.
                dic_v = int(raw) if raw.isdecimal() else raw
                i = str(compare[opt](dic_v, ast.literal_eval(i_v)))
        res.append(i)
    # `res` now holds only 'True'/'False' and the logic words kept by
    # where_parse, so this eval just applies and/or/not with Python precedence.
    return eval(' '.join(res), {'__builtins__': {}})

def limit_action(filter_res, limit_l):
    """Apply a ``limit N`` clause to the filtered rows.

    :param filter_res: rows that passed the where filter
    :param limit_l: tokens after ``limit``; only the first one is used
    :return: the first N rows, or *filter_res* itself when there is no limit
    """
    if len(limit_l) == 0:
        return filter_res
    return filter_res[:int(limit_l[0])]




def search_action(limit_res, select_l, lines):
    """Project the selected columns out of the filtered rows.

    :param limit_res: rows from where_action/limit_action; each is either a
        raw line string (no where clause) or a list of field strings
    :param select_l: tokens after ``select``, e.g. ``['*']`` or ``['name,age']``
    :param lines: header line with comma-separated column names
    :return: ``[field_names, rows]``. For ``*`` the rows are returned as
        given; otherwise each row is a list of the stripped selected values.
    """
    header = lines.strip().split(',')  # strip: the last name kept its '\n'
    if select_l[0] == '*':
        return [header, limit_res]

    # Joining the tokens lets "select name, age" (space after a comma) work.
    fields = ''.join(select_l).split(',')
    res = []
    for row in limit_res:
        # Unfiltered rows arrive as raw strings. Zipping a string directly
        # paired column names with single characters.
        values = row.strip().split(',') if isinstance(row, str) else row
        dic = dict(zip(header, values))
        res.append([dic[name].strip() for name in fields])
    return [fields, res]




# Script entry point: start at the interactive main menu.
if __name__ == '__main__':
    home_page()

# 猜你喜欢

# 转载自 https://www.cnblogs.com/watchslowly/p/9022032.html