NL2SQL Learning (6): Model 2 Code Walkthrough

import math
import json
import re
import random
import numpy as np
from collections import defaultdict

import cn2an
from tqdm import tqdm
from nl2sql.utils import read_data, read_tables, SQL, Query, Question, Table
from keras_bert import get_checkpoint_paths, load_vocabulary, Tokenizer, load_trained_model_from_checkpoint
from keras.utils.data_utils import Sequence
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import multi_gpu_model

Configuration

train_table_file = 'E:/zym_test/test/nlp/data/train/train.tables.json'
train_data_file = 'E:/zym_test/test/nlp/data/train/train.json'

val_table_file = 'E:/zym_test/test/nlp/data/val/val.tables.json'
val_data_file = 'E:/zym_test/test/nlp/data/val/val.json'

test_table_file = 'E:/zym_test/test/nlp/data/test/test.tables.json'
test_data_file = 'E:/zym_test/test/nlp/data/test/test.json'

# Download pretrained BERT model from https://github.com/ymcui/Chinese-BERT-wwm
bert_model_path = 'E:\\zym_test\\test\\nlp\\base-line\\chinese_wwm_ext_L-12_H-768_A-12'
paths = get_checkpoint_paths(bert_model_path)

task1_file = 'task1_output.json'

Reading the data

train_tables = read_tables(train_table_file)
train_data = read_data(train_data_file, train_tables)

val_tables = read_tables(val_table_file)
val_data = read_data(val_data_file, val_tables)

test_tables = read_tables(test_table_file)
test_data = read_data(test_data_file, test_tables)
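
As a quick sanity check (not part of the original pipeline), one loaded query can be inspected; the attribute names are the same ones the rest of the code relies on:

query = train_data[0]
print(query.question.text)   # the natural-language question
print(query.table.header)    # list of (column_name, column_type) pairs
print(query.sql.conds)       # gold conditions as (col_id, op_idx, value) triples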

Building the dataset

# is_float(): check whether a string parses as a number
def is_float(value):
    try:
        float(value)
        return True
    except ValueError:
        return False

# Convert Chinese numerals to Arabic numerals
def cn_to_an(string):
    try:
        # 'normal' mode also converts digit-style numerals such as '一二三' to '123'
        return str(cn2an.cn2an(string, 'normal'))
    except ValueError:
        return string

# Convert Arabic numerals to Chinese numerals
def an_to_cn(string):
    try:
        return str(cn2an.an2cn(string))
    except ValueError:
        return string

# Convert a string to a numeric string; returns None if it cannot be parsed
def str_to_num(string):
    try:
        float_val = float(cn_to_an(string))
        if int(float_val) == float_val:   
            return str(int(float_val))
        else:
            return str(float_val)
    except ValueError:
        return None

# Year normalization: strip '年', convert Chinese numerals, and add 2000 to any
# value below 1900 (so '20年' -> '2020', but note '1800年' -> '3800');
# four-digit years >= 1900 return None
def str_to_year(string):
    year = string.replace('年', '')
    year = cn_to_an(year)
    if is_float(year) and float(year) < 1900:
        year = int(year) + 2000
        return str(year)
    else:
        return None

# Load a JSON-lines file (one JSON object per line)
def load_json(json_file):
    result = []
    if json_file:
        with open(json_file, encoding='utf-8') as file:
            for line in file:
                result.append(json.loads(line))
    return result

Quick demo

print("is_float:{}->{}".format('abc', is_float('abc')))
print("is_float:{}->{}".format('1', is_float('1')))
print("is_float:{}->{}".format('1.', is_float('1.')))

print("cn_to_an:{}->{}".format('五百五十', cn_to_an('五百五十')))
print("cn_to_an:{}->{}".format('abc', cn_to_an('abc')))
print("cn_to_an:{}->{}".format('一二三', cn_to_an('一二三')))

print("an_to_cn:{}->{}".format('1', an_to_cn('1')))
print("an_to_cn:{}->{}".format('123', an_to_cn('123')))
print("an_to_cn:{}->{}".format('cb', an_to_cn('cb')))

print("str_to_num:{}->{}".format('123', str_to_num('123')))
print("str_to_num:{}->{}".format('cb', str_to_num('cb')))

print("str_to_year:{}->{}".format('20年',str_to_year('20年')))
print("str_to_year:{}->{}".format('2020年',str_to_year('2020年')))
print("str_to_year:{}->{}".format('1800年',str_to_year('1800年')))
print("str_to_year:{}->{}".format('一九年',str_to_year('一九年')))
print("str_to_year:{}->{}".format('二零一九年',str_to_year('二零一九年')))
is_float:abc->False
is_float:1->True
is_float:1.->True
cn_to_an:五百五十->550
cn_to_an:abc->abc
cn_to_an:一二三->123
an_to_cn:1->一
an_to_cn:123->一百二十三
an_to_cn:cb->cb
str_to_num:123->123
str_to_num:cb->None
str_to_year:20年->2020
str_to_year:2020年->None
str_to_year:1800年->3800
str_to_year:一九年->2019
str_to_year:二零一九年->None
class QuestionCondPair:
    def __init__(self, query_id, question, cond_text, cond_sql, label):
        self.query_id = query_id
        self.question = question
        self.cond_text = cond_text
        self.cond_sql = cond_sql
        self.label = label

    def __repr__(self):
        repr_str = ''
        repr_str += 'query_id: {}\n'.format(self.query_id)
        repr_str += 'question: {}\n'.format(self.question)
        repr_str += 'cond_text: {}\n'.format(self.cond_text)
        repr_str += 'cond_sql: {}\n'.format(self.cond_sql)
        repr_str += 'label: {}\n'.format(self.label)
        return repr_str
class NegativeSampler:
    """
    从 question - cond pairs 中采样
    """
    def __init__(self, neg_sample_ratio=10):
        self.neg_sample_ratio = neg_sample_ratio
    
    # Split positives from negatives; keep all positives plus a random subset of negatives
    def sample(self, data):
        positive_data = [d for d in data if d.label == 1]
        negative_data = [d for d in data if d.label == 0]
        negative_sample = random.sample(negative_data, 
                                        len(positive_data) * self.neg_sample_ratio)
        return positive_data + negative_sample

    
class FullSampler:
    """
    No sampling; return all pairs.
    """
    def sample(self, data):
        return data

Quick demo

a = [d for d in [1,2,3] if d < 10]
c = []
for d in [1,2,3]:
    if d<10:
        c.append(d)
print(a,c)
print(a+c)
[1, 2, 3] [1, 2, 3]
[1, 2, 3, 1, 2, 3]
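NegativeSampler works the same way: filter, sample, concatenate. A minimal sketch with made-up pairs (the question text and labels here are hypothetical):

toy_pairs = ([QuestionCondPair(0, 'q', 'c', None, 1)] * 2 +
             [QuestionCondPair(0, 'q', 'c', None, 0)] * 50)
sampled = NegativeSampler(neg_sample_ratio=10).sample(toy_pairs)
print(len(sampled))  # 2 positives + 2 * 10 sampled negatives = 22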
class CandidateCondsExtractor:
    """
    params:
        - share_candidates: 在同 table 同 column 中共享 real 型 candidates
    """
    CN_NUM = '〇一二三四五六七八九零壹贰叁肆伍陆柒捌玖貮两'
    CN_UNIT = '十拾百佰千仟万萬亿億兆点'
    
    def __init__(self, share_candidates=True):
        self.share_candidates = share_candidates
        self._cached = False
    
    # Build the candidate-value cache
    def build_candidate_cache(self, queries):
        # defaultdict(set): indexing a missing key returns an empty set()
        self.cache = defaultdict(set)
        print('building candidate cache')

        # tqdm shows a progress bar sized by the total number of queries;
        # query_id and query are the question's index and the question itself
        for query_id, query in tqdm(enumerate(queries), total=len(queries)):
            # extract the numbers and years mentioned in the question text
            value_in_question = self.extract_values_from_text(query.question.text)
            # from each table column, pick values sharing characters with the question
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                value_in_column = self.extract_values_from_column(query, col_id)
                # columns in this dataset are typed either 'text' or 'real'
                if col_type == 'text':
                    cond_values = value_in_column
                elif col_type == 'real':
                    # a real column with exactly one matching value is itself a candidate
                    if len(value_in_column) == 1:
                        cond_values = value_in_column + value_in_question
                    else:
                        cond_values = value_in_question
                cache_key = self.get_cache_key(query_id, query, col_id)
                self.cache[cache_key].update(cond_values)
        self._cached = True
    
    def get_cache_key(self, query_id, query, col_id):
        if self.share_candidates:
            return (query.table.id, col_id)
        else:
            return (query_id, query.table.id, col_id)
    
    # Extract year mentions from text
    def extract_year_from_text(self, text):
        values = []
        # find two digits followed by '年'
        num_year_texts = re.findall(r'[0-9][0-9]年', text)
        # prefix the two-digit year with '20'
        values += ['20{}'.format(text[:-1]) for text in num_year_texts]
        # find two Chinese-numeral characters followed by '年'
        cn_year_texts = re.findall(r'[{}][{}]年'.format(self.CN_NUM, self.CN_NUM), text)
        # convert the Chinese numerals to Arabic years
        cn_year_values = [str_to_year(text) for text in cn_year_texts]
        values += [value for value in cn_year_values if value is not None]
        return values

    # Extract numeric values: plain digits, Chinese numerals, and mixed digit-unit forms
    def extract_num_from_text(self, text):
        values = []
        num_values = re.findall(r'[-+]?[0-9]*\.?[0-9]+', text)
        values += num_values
        
        cn_num_unit = self.CN_NUM + self.CN_UNIT
        cn_num_texts = re.findall(r'[{}]*\.?[{}]+'.format(cn_num_unit, cn_num_unit), text)
        cn_num_values = [str_to_num(text) for text in cn_num_texts]
        values += [value for value in cn_num_values if value is not None]
    
        cn_num_mix = re.findall(r'[0-9]*\.?[{}]+'.format(self.CN_UNIT), text)
        for word in cn_num_mix:
            num = re.findall(r'[-+]?[0-9]*\.?[0-9]+', word)
            for n in num:
                word = word.replace(n, an_to_cn(n))
            str_num = str_to_num(word)
            if str_num is not None:
                values.append(str_num)
        return values
    
    def extract_values_from_text(self, text):
        values = []
        values += self.extract_year_from_text(text)
        values += self.extract_num_from_text(text)
        return list(set(values))
    
    # Keep the column values that share at least one character with the question
    def extract_values_from_column(self, query, col_ids):
        question = query.question.text
        question_chars = set(query.question.text)
        unique_col_values = set(query.table.df.iloc[:, col_ids].astype(str))
        select_col_values = [v for v in unique_col_values 
                             if (question_chars & set(v))]
        return select_col_values
text = "今年是2020年,明年是2021年,去年是二零一九年"
text_1 = '我的天哪哈哈哈18年,二零二二年,一九年'
CN_NUM = '〇一二三四五六七八九零壹贰叁肆伍陆柒捌玖貮两'
CN_UNIT = '十拾百佰千仟万萬亿億兆点'
    
values = []
num_year_texts = re.findall(r'[0-9][0-9]年', text)
print(num_year_texts)
values += ['20{}'.format(text[:-1]) for text in num_year_texts]
print(values)
for text in num_year_texts:
    print(text)
    # strip the trailing '年'
    print(text[:-1])
cn_year_texts = re.findall(r'[{}][{}]年'.format(CN_NUM, CN_NUM), text_1)
print(cn_year_texts)
cn_year_values = [str_to_year(text) for text in cn_year_texts]
print(cn_year_values)
['20年', '21年']
['2020', '2021']
20年
20
21年
21
['二二年', '一九年']
['2022', '2019']
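The number extractor deserves a similar check. It handles three forms: plain digits, pure Chinese numerals, and mixed digit-unit strings such as '3万' (which it rewrites as '三万' before parsing). A small hedged example; the exact extras depend on how cn2an treats bare unit characters:

extractor = CandidateCondsExtractor()
print(extractor.extract_num_from_text('涨幅超过3万的股票'))
# should contain '3' (raw digits) and '30000' (the mixed form '3万');
# whether a bare '万' also parses depends on the cn2an version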
class QuestionCondPairsDataset:
    """
    question - cond pairs 数据集
    """
    OP_PATTERN = {
        'real':
        [
            {'cond_op_idx': 0, 'pattern': '{col_name}大于{value}'},
            {'cond_op_idx': 1, 'pattern': '{col_name}小于{value}'},
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ],
        'text':
        [
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ]
    }    
    
    def __init__(self, queries, candidate_extractor, has_label=True, model_1_outputs=None):
        self.candidate_extractor = candidate_extractor
        self.has_label = has_label
        self.model_1_outputs = model_1_outputs
        self.data = self.build_dataset(queries)
        
    def build_dataset(self, queries):
        if not self.candidate_extractor._cached:
            self.candidate_extractor.build_candidate_cache(queries)
            
        pair_data = []
        for query_id, query in enumerate(queries):
            select_col_id = self.get_select_col_id(query_id, query)
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                if col_id not in select_col_id:
                    continue
                    
                cache_key = self.candidate_extractor.get_cache_key(query_id, query, col_id)
                values = self.candidate_extractor.cache.get(cache_key, [])
                pattern = self.OP_PATTERN.get(col_type, [])
                pairs = self.generate_pairs(query_id, query, col_id, col_name, 
                                               values, pattern)
                pair_data += pairs
        return pair_data
    
    def get_select_col_id(self, query_id, query):
        if self.model_1_outputs:
            select_col_id = [cond_col for cond_col, *_ in self.model_1_outputs[query_id]['conds']]
        elif self.has_label:
            select_col_id = [cond_col for cond_col, *_ in query.sql.conds]
        else:
            select_col_id = list(range(len(query.table.header)))
        return select_col_id
            
    def generate_pairs(self, query_id, query, col_id, col_name, values, op_patterns):
        pairs = []
        for value in values:
            for op_pattern in op_patterns:
                cond = op_pattern['pattern'].format(col_name=col_name, value=value)
                cond_sql = (col_id, op_pattern['cond_op_idx'], value)
                real_sql = {}
                if self.has_label:
                    real_sql = {tuple(c) for c in query.sql.conds}
                label = 1 if cond_sql in real_sql else 0
                pair = QuestionCondPair(query_id, query.question.text,
                                        cond, cond_sql, label)
                pairs.append(pair)
        return pairs
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
task1_result = load_json(task1_file)

tr_qc_pairs = QuestionCondPairsDataset(train_data, 
                                       candidate_extractor=CandidateCondsExtractor(share_candidates=False))

te_qc_pairs = QuestionCondPairsDataset(test_data, 
                                       candidate_extractor=CandidateCondsExtractor(share_candidates=True),
                                       has_label=False,
                                       model_1_outputs=task1_result)
building candidate cache
100%|██████████| 41522/41522 [00:58<00:00, 706.46it/s]
building candidate cache
100%|██████████| 4086/4086 [00:06<00:00, 656.83it/s]
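
Each element of the dataset is a QuestionCondPair, so indexing one gives a readable record via the __repr__ defined above (a quick inspection, not in the original code):

print(len(tr_qc_pairs))  # total number of question-cond pairs
print(tr_qc_pairs[0])    # query_id / question / cond_text / cond_sql / label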

Building the model

class SimpleTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')
            else:
                R.append('[UNK]')
        return R

            
def construct_model(paths, use_multi_gpus=False):
    token_dict = load_vocabulary(paths.vocab)
    tokenizer = SimpleTokenizer(token_dict)

    bert_model = load_trained_model_from_checkpoint(
        paths.config, paths.checkpoint, seq_len=None)
    for l in bert_model.layers:
        l.trainable = True

    x1_in = Input(shape=(None,), name='input_x1', dtype='int32')
    x2_in = Input(shape=(None,), name='input_x2')
    x = bert_model([x1_in, x2_in])
    x_cls = Lambda(lambda x: x[:, 0])(x)
    y_pred = Dense(1, activation='sigmoid', name='output_similarity')(x_cls)
    # functional-API Model: wire the two inputs to the similarity output
    model = Model([x1_in, x2_in], y_pred)
    if use_multi_gpus:
        print('using multi-gpus')
        model = multi_gpu_model(model, gpus=2)

    model.compile(loss={'output_similarity': 'binary_crossentropy'},
                  optimizer=Adam(1e-5),
                  metrics={'output_similarity': 'accuracy'})

    return model, tokenizer
model, tokenizer = construct_model(paths)
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_x1 (InputLayer)           (None, None)         0                                            
__________________________________________________________________________________________________
input_x2 (InputLayer)           (None, None)         0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, None, 768)    101677056   input_x1[0][0]                   
                                                                 input_x2[0][0]                   
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 768)          0           model_2[1][0]                    
__________________________________________________________________________________________________
output_similarity (Dense)       (None, 1)            769         lambda_1[0][0]                   
==================================================================================================
Total params: 101,677,825
Trainable params: 101,677,825
Non-trainable params: 0
__________________________________________________________________________________________________
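
keras_bert's Tokenizer.encode takes the question as the first segment and the condition text as the second, returning token ids and segment ids. A quick check with the tokenizer built above (the example strings are arbitrary):

x1, x2 = tokenizer.encode(first='二零一九年营收是多少', second='年份是2019')
print(x1)  # token ids for [CLS] question [SEP] condition [SEP]
print(x2)  # segment ids: 0 for the first segment, 1 for the second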

Building the input data

class QuestionCondPairsDataseq(Sequence):
    def __init__(self, dataset, tokenizer, is_train=True, max_len=120, 
                 sampler=None, shuffle=False, batch_size=32):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.max_len = max_len
        self.sampler = sampler
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.on_epoch_end()       
    
    def _pad_sequences(self, seqs, max_len=None):
        return pad_sequences(seqs, maxlen=max_len, padding='post', truncating='post')
    
    def __getitem__(self, batch_id):
        batch_data_indices = \
            self.global_indices[batch_id * self.batch_size: (batch_id + 1) * self.batch_size]
        batch_data = [self.data[i] for i in batch_data_indices]

        X1, X2 = [], []
        Y = []
        
        for data in batch_data:
            x1, x2 = self.tokenizer.encode(first=data.question.lower(), 
                                           second=data.cond_text.lower())
            X1.append(x1)
            X2.append(x2)
            if self.is_train:
                Y.append([data.label])
    
        X1 = self._pad_sequences(X1, max_len=self.max_len)
        X2 = self._pad_sequences(X2, max_len=self.max_len)
        inputs = {'input_x1': X1, 'input_x2': X2}
        if self.is_train:
            Y = self._pad_sequences(Y, max_len=1)
            outputs = {'output_similarity': Y}
            return inputs, outputs
        else:
            return inputs
                    
    def on_epoch_end(self):
        # called by Keras after each epoch: resample (fresh negatives) and reshuffle
        self.data = self.sampler.sample(self.dataset)
        self.global_indices = np.arange(len(self.data))
        if self.shuffle:
            np.random.shuffle(self.global_indices)
    
    def __len__(self):
        return math.ceil(len(self.data) / self.batch_size)
tr_qc_pairs_seq = QuestionCondPairsDataseq(tr_qc_pairs, tokenizer, 
                                           sampler=NegativeSampler(), shuffle=True)

te_qc_pairs_seq = QuestionCondPairsDataseq(te_qc_pairs, tokenizer, 
                                           sampler=FullSampler(), shuffle=False, batch_size=128)
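
Indexing one batch shows the shapes fed to the network (a quick check, not part of the original pipeline):

inputs, outputs = tr_qc_pairs_seq[0]
print(inputs['input_x1'].shape)            # (32, 120): padded/truncated token ids
print(outputs['output_similarity'].shape)  # (32, 1): binary labels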

Training the model

model.fit_generator(tr_qc_pairs_seq, epochs=5, workers=4)
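
Nothing is checkpointed above; to keep the best weights across epochs, a minimal sketch using the standard Keras callback (the filename is an arbitrary choice):

from keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint('model2_best_weights.h5', monitor='loss',
                             save_best_only=True, save_weights_only=True)
model.fit_generator(tr_qc_pairs_seq, epochs=5, workers=4,
                    callbacks=[checkpoint])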

Predicting on the test set

te_result = model.predict_generator(te_qc_pairs_seq, verbose=1)

Making predictions for task 2

def merge_result(qc_pairs, result, threshold):
    select_result = defaultdict(set)
    for pair, score in zip(qc_pairs, result):
        if score > threshold:
            select_result[pair.query_id].update([pair.cond_sql])
    return dict(select_result)

task2_result = merge_result(te_qc_pairs, te_result, threshold=0.995)   
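
A tiny synthetic check of merge_result (pairs and scores are made up): only pairs scoring above the threshold survive, grouped into sets keyed by query_id:

demo_pairs = [QuestionCondPair(0, 'q', '营收大于100', (1, 0, '100'), None),
              QuestionCondPair(0, 'q', '营收小于100', (1, 1, '100'), None)]
print(merge_result(demo_pairs, [0.999, 0.3], threshold=0.995))
# -> {0: {(1, 0, '100')}}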

Final output

final_output_file = 'final_output.json'
with open(final_output_file, 'w', encoding='utf-8') as f:
    for query_id, pred_sql in enumerate(task1_result):
        cond = list(task2_result.get(query_id, []))
        pred_sql['conds'] = cond
        json_str = json.dumps(pred_sql, ensure_ascii=False)
        f.write(json_str + '\n')