Nl2sql学习(5):model1代码学习(详细注释)

整体流程

  1. 数据的读取

  2. 数据的处理

    • 输入:问句和Table表头的数字化(Tokenization)
    • 标签:将SQL语句转换为模型训练所需的标签形式
    • 模型所需数据的构建
  3. 构建模型

    • 输入数据的bert-encoding
    • encoding后经全连接层输出
  4. 模型训练:设置callbacks选择最佳模型

  5. 预测

代码

import os
import re
import json
import math
import numpy as np
from tqdm import tqdm

from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths

import keras.backend as K
from keras.layers import Input, Dense, Lambda, Multiply, Masking, Concatenate
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences
from keras.callbacks import Callback, ModelCheckpoint
from keras.utils.data_utils import Sequence
from keras.utils import multi_gpu_model

from nl2sql.utils import read_data, read_tables, SQL, MultiSentenceTokenizer, Query, Question, Table
from nl2sql.utils.optimizer import RAdam

Configuration

# Dataset locations. The root directory can be overridden with the
# NL2SQL_DATA_DIR environment variable; the default is the author's local path.
data_dir = os.environ.get('NL2SQL_DATA_DIR', 'E:/zym_test/test/nlp/data')

train_table_file = data_dir + '/train/train.tables.json'
train_data_file = data_dir + '/train/train.json'

val_table_file = data_dir + '/val/val.tables.json'
val_data_file = data_dir + '/val/val.json'

test_table_file = data_dir + '/test/test.tables.json'
test_data_file = data_dir + '/test/test.json'

# Download pretrained BERT model from https://github.com/ymcui/Chinese-BERT-wwm
# The checkpoint directory can be overridden with the BERT_MODEL_PATH environment variable.
bert_model_path = os.environ.get(
    'BERT_MODEL_PATH',
    'E:\\zym_test\\test\\nlp\\base-line\\chinese_wwm_ext_L-12_H-768_A-12')
paths = get_checkpoint_paths(bert_model_path)

数据的读取与展示

def _load_split(table_file, data_file):
    """Read one dataset split: first its tables, then the queries that use them."""
    tables = read_tables(table_file)
    return tables, read_data(data_file, tables)


train_tables, train_data = _load_split(train_table_file, train_data_file)
val_tables, val_data = _load_split(val_table_file, val_data_file)
test_tables, test_data = _load_split(test_table_file, test_data_file)

# Inspect the first training query (displayed when run as a notebook cell).
sample_query = train_data[0]
sample_query
影片名称 周票房(万) 票房占比(%) 场均人次
0 死侍2:我爱我家 10637.3 25.8 5.0
1 白蛇:缘起 10503.8 25.4 7.0
2 大黄蜂 6426.6 15.6 6.0
3 密室逃生 5841.4 14.2 6.0
4 “大”人物 3322.9 8.1 5.0
5 家和万事惊 635.2 1.5 25.0
6 钢铁飞龙之奥特曼崛起 595.5 1.4 3.0
7 海王 500.3 1.2 5.0
8 一条狗的回家路 360.0 0.9 4.0
9 掠食城市 356.6 0.9 3.0

二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀
sel: [2]
agg: ['SUM']
cond_conn_op: 'or'
conds: [[0, '==', '大黄蜂'], [0, '==', '密室逃生']]
# Show the table attached to the sample query (displayed when run as a notebook cell).
sample_query.table
影片名称 周票房(万) 票房占比(%) 场均人次
0 死侍2:我爱我家 10637.3 25.8 5.0
1 白蛇:缘起 10503.8 25.4 7.0
2 大黄蜂 6426.6 15.6 6.0
3 密室逃生 5841.4 14.2 6.0
4 “大”人物 3322.9 8.1 5.0
5 家和万事惊 635.2 1.5 25.0
6 钢铁飞龙之奥特曼崛起 595.5 1.4 3.0
7 海王 500.3 1.2 5.0
8 一条狗的回家路 360.0 0.9 4.0
9 掠食城市 356.6 0.9 3.0
# Show the natural-language question of the sample query.
sample_query.question
二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀
# Show the sample query's SQL label (sel / agg / cond_conn_op / conds).
sample_query.sql

sel: [2]
agg: ['SUM']
cond_conn_op: 'or'
conds: [[0, '==', '大黄蜂'], [0, '==', '密室逃生']]

Query的Tokenization与展示

# Matches one bracketed annotation such as "(万)" or "（%）". Both ASCII and
# full-width parentheses are accepted. The non-greedy ".*?" stops at the first
# closing bracket, so text sitting between two separate bracket groups is kept.
_BRACKETS_RE = re.compile(r'[\(（].*?[\)）]')


def remove_brackets(s):
    """Strip bracketed annotations (e.g. units like "(万)" or "（%）") from ``s``.

    Args:
        s: a column name or other short string.

    Returns:
        ``s`` with every ``(...)`` / ``（...）`` group removed.
    """
    return _BRACKETS_RE.sub('', s)


# Tokenizes a query (question + table header) and converts the tokens to ids.
class QueryTokenizer(MultiSentenceTokenizer):
    """Tokenize a question together with its table header for BERT.

    Every column is prefixed with a marker token describing its type:
    ``[unused11]`` for text columns and ``[unused12]`` for real columns.
    """
    col_type_token_dict = {'text': '[unused11]', 'real': '[unused12]'}

    def tokenize(self, query: Query, col_orders=None):
        """Build and pack the token groups for ``query``.

        Args:
            query: query whose question text and table header are tokenized.
            col_orders: optional sequence of column indices giving the order
                in which header columns are emitted; defaults to the natural
                column order.

        Returns:
            The result of ``self._pack`` over the groups
            ``[CLS] + question tokens`` followed by one
            ``[type marker] + column-name tokens`` group per column.
            ``encode`` unpacks it as ``(tokens, group_lengths)``.
        """
        columns = query.table.header
        if col_orders is None:
            col_orders = np.arange(len(columns))

        groups = [[self._token_cls] + self._tokenize(query.question.text)]
        for col_idx in col_orders:
            # Each header entry is a (column name, column type) pair.
            col_name, col_type = columns[col_idx]
            groups.append([self.col_type_token_dict[col_type]]
                          + self._tokenize(remove_brackets(col_name)))
        return self._pack(*groups)

    def encode(self, query: Query, col_orders=None):
        """Convert ``query`` to BERT input ids.

        Returns:
            token_ids: vocabulary ids of all packed tokens.
            segment_ids: all zeros, one per token id.
            header_indices: numpy array holding the start offset of every
                column group, i.e. the position of its type-marker token.
        """
        tokens, group_lengths = self.tokenize(query, col_orders)
        token_ids = self._convert_tokens_to_ids(tokens)
        segment_ids = [0] * len(token_ids)
        # Running totals of the group lengths are the offsets where the next
        # group starts; the final total is the overall length, so it is dropped.
        group_starts = np.cumsum(group_lengths)
        return token_ids, segment_ids, group_starts[:-1]
# Tokenize the sample query. Tokenizer (the base of MultiSentenceTokenizer,
# which QueryTokenizer extends) needs the vocabulary as a token -> id dict.
token_dict = load_vocabulary(paths.vocab)
query_tokenizer = QueryTokenizer(token_dict)

sample_tokens = query_tokenizer.tokenize(sample_query)[0]
print('Output Tokens:\n{}\n'.format(' '.join(sample_tokens)))

# header_ids are the positions of the column-type markers inside the token sequence.
sample_token_ids, sample_segment_ids, sample_header_ids = query_tokenizer.encode(sample_query)
print('Output token_ids:\n{}\n\nOutput segment_ids:\n{}\n\nOutput header_ids:\n{}'
      .format(sample_token_ids, sample_segment_ids, sample_header_ids))
Output Tokens:
[CLS] 二 零 一 九 年 第 四 周 大 黄 蜂 和 密 室 逃 生 这 两 部 影 片 的 票 房 总 占 比 是 多 少 呀 [SEP] [unused11] 影 片 名 称 [SEP] [unused12] 周 票 房 [SEP] [unused12] 票 房 占 比 [SEP] [unused12] 场 均 人 次 [SEP]

Output token_ids:
[101, 753, 7439, 671, 736, 2399, 5018, 1724, 1453, 1920, 7942, 6044, 1469, 2166, 2147, 6845, 4495, 6821, 697, 6956, 2512, 4275, 4638, 4873, 2791, 2600, 1304, 3683, 3221, 1914, 2208, 1435, 102, 11, 2512, 4275, 1399, 4917, 102, 12, 1453, 4873, 2791, 102, 12, 4873, 2791, 1304, 3683, 102, 12, 1767, 1772, 782, 3613, 102]

Output segment_ids:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Output header_ids:
[33 39 44 50]

小demo:

# Demo: `[a] + [4]` wraps list `a` as one element before concatenating,
# so the result is a nested list.
a = [1,2,3]
b = [a]+[4]
b
[[1, 2, 3], 4]

Sql语句的转换与展示

class SqlLabelEncoder:
    """Convert SQL objects to the per-column label arrays used in training, and back.

    Label conventions (one entry per table column):

    * ``sel_agg_label[c]`` is the aggregation id selected for column ``c``, or
      ``len(SQL.agg_sql_dict)`` ("no-op") when the column is not selected.
    * ``cond_op_label[c]`` is the condition operator id used on column ``c``, or
      ``len(SQL.op_sql_dict)`` ("no-op") when the column has no condition.
    * ``cond_conn_op_label`` is ``sql.cond_conn_op`` unchanged.
    """

    def encode(self, sql: SQL, num_cols):
        """Encode ``sql`` for a table with ``num_cols`` columns.

        Column ids ``>= num_cols`` are ignored. If several conditions use the
        same column, the last one wins, because each column holds one operator.

        Returns:
            ``(cond_conn_op_label, sel_agg_label, cond_op_label)``. Both arrays
            are int32 of length ``num_cols``.
        """
        cond_conn_op_label = sql.cond_conn_op

        # Every column starts as "no aggregation"; selected columns are overwritten.
        sel_agg_label = np.full(num_cols, len(SQL.agg_sql_dict), dtype='int32')
        for col_id, agg_op in zip(sql.sel, sql.agg):
            if col_id < num_cols:
                sel_agg_label[col_id] = agg_op

        # Every column starts as "no condition". The no-op id must come from
        # op_sql_dict: decode() tests against len(SQL.op_sql_dict), and the
        # cond_op model head has len(SQL.op_sql_dict) + 1 classes, so a larger
        # id would be an out-of-range class label.
        cond_op_label = np.full(num_cols, len(SQL.op_sql_dict), dtype='int32')
        for col_id, cond_op, _cond_value in sql.conds:
            if col_id < num_cols:
                cond_op_label[col_id] = cond_op

        return cond_conn_op_label, sel_agg_label, cond_op_label

    def decode(self, cond_conn_op_label, sel_agg_label, cond_op_label):
        """Inverse of ``encode``; condition values are not recovered.

        Returns:
            dict with ``sel`` (selected column ids), ``agg`` (their aggregation
            ids), ``cond_conn_op`` (int) and ``conds`` (``[col_id, op]`` pairs).
        """
        cond_conn_op = int(cond_conn_op_label)
        no_agg = len(SQL.agg_sql_dict)
        no_cond = len(SQL.op_sql_dict)
        sel, agg, conds = [], [], []
        for col_id, (agg_op, cond_op) in enumerate(zip(sel_agg_label, cond_op_label)):
            # Ids below the no-op id are real aggregations / operators.
            if agg_op < no_agg:
                sel.append(col_id)
                agg.append(int(agg_op))
            if cond_op < no_cond:
                conds.append([col_id, int(cond_op)])
        return {
            'sel': sel,
            'agg': agg,
            'cond_conn_op': cond_conn_op,
            'conds': conds
        }
# Create the label encoder and encode the sample query's SQL over all of its columns.
label_encoder = SqlLabelEncoder()
sample_num_cols = len(sample_query.table.header)
label_encoder.encode(sample_query.sql, num_cols=sample_num_cols)
(2, array([6, 6, 5, 6]), array([2, 6, 6, 6]))
# Round trip: decoding the encoded labels recovers sel/agg/cond_conn_op and the
# (column, operator) pairs of the conditions; condition values are not part of the labels.
sample_labels = label_encoder.encode(sample_query.sql, num_cols=len(sample_query.table.header))
label_encoder.decode(*sample_labels)
{'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [[0, 2]]}

小demo

# Demo: multiplying an all-ones int32 array by an integer yields an array
# filled with that integer.
x = np.ones(5, dtype='int32')
x_1 = x*len([1,2])
print(x,x_1)
[1 1 1 1 1] [2 2 2 2 2]
# Demo: enumerate(zip(a, b)) yields (index, (a[i], b[i])) pairs.
a= [1,2,3,4]
b= [5,6,7,8]
c = zip(a,b)
d = enumerate(c)
for x,y in d:
    print(x,y)
0 (1, 5)
1 (2, 6)
2 (3, 7)
3 (4, 8)

输入模型的训练数据及其展示

class DataSequence(Sequence):
    """Keras ``Sequence`` that yields model-input batches (plus labels when training).

    Each batch is a dict of inputs ``input_token_ids``, ``input_segment_ids``,
    ``input_header_ids`` and ``input_header_mask``. When ``is_train`` is True
    it is returned together with a dict of labels ``output_sel_agg``,
    ``output_cond_conn_op`` and ``output_cond_op``.

    Args:
        data: indexable collection of queries.
        tokenizer: object whose ``encode(query, col_orders)`` returns
            ``(token_ids, segment_ids, header_ids)``.
        label_encoder: object whose ``encode(sql, num_cols)`` returns
            ``(cond_conn_op, sel_agg, cond_op)``.
        is_train: also build labels when True.
        max_len: token sequences are cut to this length; columns whose marker
            position is at or beyond it are dropped.
        batch_size: number of queries per batch.
        shuffle: shuffle the query order at construction and after every epoch.
        shuffle_header: randomly permute each query's column order.
        global_indices: optional subset/order of indices into ``data``. It is
            copied, so the caller's array is never shuffled in place, and the
            number of batches is based on it.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 label_encoder,
                 is_train=True,
                 max_len=160,
                 batch_size=32,
                 shuffle=True,
                 shuffle_header=True,
                 global_indices=None):
        # Input data
        self.data = data
        self.batch_size = batch_size
        # Turns a query into token ids / segment ids / header positions
        self.tokenizer = tokenizer
        # Turns a query's SQL into label arrays
        self.label_encoder = label_encoder
        # Other options
        self.shuffle = shuffle
        self.shuffle_header = shuffle_header
        self.is_train = is_train
        self.max_len = max_len

        # Indices into `data` that this sequence iterates over.
        if global_indices is None:
            self._global_indices = np.arange(len(data))
        else:
            # Copy so that shuffling never reorders the caller's array.
            self._global_indices = np.array(global_indices)

        if shuffle:
            np.random.shuffle(self._global_indices)

    def _pad_sequences(self, seqs, max_len=None):
        """Right-pad ``seqs`` with zeros to the longest length, then cut to ``max_len`` if given."""
        # 'post': pad / truncate at the end of each sequence.
        padded = pad_sequences(seqs, maxlen=None, padding='post', truncating='post')
        if max_len is not None:
            padded = padded[:, :max_len]
        return padded

    def __getitem__(self, batch_id):
        start = batch_id * self.batch_size
        batch_indices = self._global_indices[start:start + self.batch_size]
        batch_data = [self.data[i] for i in batch_indices]

        # Inputs
        TOKEN_IDS, SEGMENT_IDS = [], []
        HEADER_IDS, HEADER_MASK = [], []
        # Labels
        COND_CONN_OP = []
        SEL_AGG = []
        COND_OP = []

        for query in batch_data:
            table = query.table
            col_orders = np.arange(len(table.header))
            if self.shuffle_header:
                np.random.shuffle(col_orders)

            token_ids, segment_ids, header_ids = self.tokenizer.encode(query, col_orders)
            # Token ids are cut to max_len below. header_ids is not aligned with
            # token_ids, so _pad_sequences cannot cut it; drop the columns whose
            # marker position would fall outside the kept tokens instead.
            header_ids = [hid for hid in header_ids if hid < self.max_len]
            header_mask = [1] * len(header_ids)
            # Header positions increase with column order, so the kept columns
            # are the first len(header_ids) entries of col_orders.
            col_orders = col_orders[:len(header_ids)]

            TOKEN_IDS.append(token_ids)
            SEGMENT_IDS.append(segment_ids)
            HEADER_IDS.append(header_ids)
            HEADER_MASK.append(header_mask)

            # Inference mode: inputs only, no labels.
            if not self.is_train:
                continue

            cond_conn_op, sel_agg, cond_op = self.label_encoder.encode(
                query.sql, num_cols=len(table.header))
            # Reorder per-column labels to match the (shuffled / truncated) column order.
            COND_CONN_OP.append(cond_conn_op)
            SEL_AGG.append(sel_agg[col_orders])
            COND_OP.append(cond_op[col_orders])

        TOKEN_IDS = self._pad_sequences(TOKEN_IDS, max_len=self.max_len)
        SEGMENT_IDS = self._pad_sequences(SEGMENT_IDS, max_len=self.max_len)
        HEADER_IDS = self._pad_sequences(HEADER_IDS)
        HEADER_MASK = self._pad_sequences(HEADER_MASK)

        inputs = {
            'input_token_ids': TOKEN_IDS,
            'input_segment_ids': SEGMENT_IDS,
            'input_header_ids': HEADER_IDS,
            'input_header_mask': HEADER_MASK
        }

        if not self.is_train:
            return inputs

        SEL_AGG = np.expand_dims(self._pad_sequences(SEL_AGG), axis=-1)
        COND_CONN_OP = np.expand_dims(COND_CONN_OP, axis=-1)
        COND_OP = np.expand_dims(self._pad_sequences(COND_OP), axis=-1)

        outputs = {
            'output_sel_agg': SEL_AGG,
            'output_cond_conn_op': COND_CONN_OP,
            'output_cond_op': COND_OP
        }
        return inputs, outputs

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return math.ceil(len(self._global_indices) / self.batch_size)

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self._global_indices)
# A small, unshuffled training sequence used only to inspect one batch.
train_seq = DataSequence(
    data=train_data,
    tokenizer=query_tokenizer,
    label_encoder=label_encoder,
    shuffle=False,
    max_len=160,
    batch_size=2,
)
# Show the first batch: every input array, then every label array.
sample_batch_inputs, sample_batch_outputs = train_seq[0]
for part in (sample_batch_inputs, sample_batch_outputs):
    for name, data in part.items():
        print('{} : shape{}'.format(name, data.shape))
        print(data, '\n')
input_token_ids : shape(2, 57)
[[ 101  753 7439  671  736 2399 5018 1724 1453 1920 7942 6044 1469 2166
  2147 6845 4495 6821  697 6956 2512 4275 4638 4873 2791 2600 1304 3683
  3221 1914 2208 1435  102   11 2512 4275 1399 4917  102   12 1453 4873
  2791  102   12 1767 1772  782 3613  102   12 4873 2791 1304 3683  102
     0]
 [ 101  872 1962 8024  872 4761 6887  791 2399 5018 1724 1453 2166 2147
  6845 4495 8024 6820 3300 6929 6956 1920 7942 6044 2124  812 4873 2791
  2600 4638 1304 3683 1408  102   12 4873 2791 1304 3683  102   12 1767
  1772  782 3613  102   12 1453 4873 2791  102   11 2512 4275 1399 4917
   102]] 

input_segment_ids : shape(2, 57)
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]] 

input_header_ids : shape(2, 4)
[[33 39 44 50]
 [34 40 46 51]] 

input_header_mask : shape(2, 4)
[[1 1 1 1]
 [1 1 1 1]] 

output_sel_agg : shape(2, 4, 1)
[[[6]
  [6]
  [6]
  [5]]

 [[5]
  [6]
  [6]
  [6]]] 

output_cond_conn_op : shape(2, 1)
[[2]
 [2]] 

output_cond_op : shape(2, 4, 1)
[[[2]
  [6]
  [6]
  [6]]

 [[6]
  [6]
  [6]
  [2]]] 
# Inference-style sequence: no labels, columns kept in their original order.
val_seq = DataSequence(val_data, query_tokenizer, label_encoder,
                       is_train=False, max_len=160, batch_size=2,
                       shuffle_header=False)
# Show the first batch (inputs only, because is_train=False).
sample_batch_inputs = val_seq[0]
for name in sample_batch_inputs:
    data = sample_batch_inputs[name]
    print('{} : shape{}'.format(name, data.shape))
    print(data, '\n')
input_token_ids : shape(2, 160)
[[ 101 2769 2682 4761 6887  122  123 2399 2791 1765  772 2458 1355 4638
  3198  952 2124 4638 5318 2190 7030 7770  754  122  121 5445  684 1398
  3683 1872 1920  738 1762 4636 1146  722  122  121  809  677 4638 2900
  3403 3300 1525  763 8043  102   11 2900 3403  102   12 5318 2190 7030
   102   12 1398 3683 1872 7270  102    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0]
 [ 101 2600 1066 3300 1914 2208  702 1814 2356  123  121  122  123 2399
   123 3299 2768  769 7030 4384 3683 1920  754  122  121 2400  684  123
   121  122  123 2399  122 3299 2768  769 7030 4384 3683 2207  754  122
  4638 8043  102   11 1814 2356  102   12  123  121  122  123 2399  125
  3299 2768  769 7030 4384 3683  102   12  123  121  122  123 2399  124
  3299 2768  769 7030 4384 3683  102   12  123  121  122  123 2399  123
  3299 2768  769 7030 4384 3683  102   12  123  121  122  123 2399  122
  3299 2768  769 7030 4384 3683  102   12  123  121  122  122 2399  122
   123 3299 2768  769 7030 4384 3683  102   12  123  121  122  122 2399
   122  122 3299 2768  769 7030 4384 3683  102   12  123  121  122  122
  2399  122  121 3299 2768  769 7030 4384 3683  102   12  123  121  122
   123 2399  125 3299 2768  769]] 

input_segment_ids : shape(2, 160)
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]] 

input_header_ids : shape(2, 9)
[[ 48  52  57   0   0   0   0   0   0]
 [ 45  49  63  77  91 105 120 135 150]] 

input_header_mask : shape(2, 9)
[[1 1 1 0 0 0 0 0 0]
 [1 1 1 1 1 1 1 1 1]] 

小demo

# Demo: a list comprehension builds the same list as an explicit append loop.
a = [i for i in [1,2,3,4,5]]
b = []
for i in [1,2,3,4,5]:
    b.append(i)
print(a,'\n',b)
[1, 2, 3, 4, 5] 
 [1, 2, 3, 4, 5]
# Demo: appending a list adds it as a single nested element.
c = [1,2,3]
d = []
d.append(c)
d
[[1, 2, 3]]
# Demo 1: `continue` skips the rest of the loop body, so print(1) never runs.
a = True
for i in range(2):
    if a == True:
        continue
    print(1)
# Demo 2: indexing a numpy array with an index array or list returns the
# elements at those positions, in that order.
sl = np.array([1,2,3,4])
print(sl)
e = np.arange(2)
print(e)
e_1 = [2,1]
s = sl[e]
s_1 = sl[e_1]
print(s)
print(s_1)
[1 2 3 4]
[0 1]
[1 2]
[3 2]

构建模型

# Number of classes per output head. The sel_agg and cond_op heads get one
# extra class (id == len(dict)) meaning "column not used".
num_sel_agg, num_cond_op, num_cond_conn_op = (
    len(SQL.agg_sql_dict) + 1,
    len(SQL.op_sql_dict) + 1,
    len(SQL.conn_sql_dict),
)
def seq_gather(x):
    """Gather, for every batch row, the vectors at the given positions.

    Args:
        x: pair ``(seq, idxs)``; ``idxs`` is cast to int32 before gathering.

    Returns:
        ``K.tf.batch_gather(seq, idxs)``.
    """
    sequence, positions = x
    return K.tf.batch_gather(sequence, K.cast(positions, 'int32'))
# Build the model: BERT encoder + three softmax heads
# (cond_conn_op from [CLS], sel_agg and cond_op per table column).
# 1. Encode the inputs with the pretrained BERT encoder.
# 1.1 Load the pretrained weights.
bert_model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=None)
# 1.2 Fine-tune every BERT layer.
for l in bert_model.layers:
    l.trainable = True
# 1.3 Model inputs (names match the keys produced by DataSequence).
inp_token_ids = Input(shape=(None,), name='input_token_ids', dtype='int32')
inp_segment_ids = Input(shape=(None,), name='input_segment_ids', dtype='int32')
inp_header_ids = Input(shape=(None,), name='input_header_ids', dtype='int32')
inp_header_mask = Input(shape=(None, ), name='input_header_mask')
# 1.4 BERT output x: (batch, seq_len, 768) per the model summary.
x = bert_model([inp_token_ids, inp_segment_ids])

# 2. Route the encoded sequence to the task-specific heads.
# 2.1.1 Take the [CLS] vector (position 0) as the sentence representation
#       used to predict cond_conn_op: (batch, 768).
x_for_cond_conn_op = Lambda(lambda x: x[:, 0])(x) 
# 2.1.2 Softmax head for cond_conn_op.
p_cond_conn_op = Dense(num_cond_conn_op, activation='softmax', name='output_cond_conn_op')(x_for_cond_conn_op)
# 2.2.1 Gather the encoder vector at each column's type-marker position
#       (inp_header_ids): (batch, header_len, 768).
x_for_header = Lambda(seq_gather, name='header_seq_gather')([x, inp_header_ids]) 
# 2.2.2 Add a trailing axis to the header mask: (batch, header_len, 1),
#       so it broadcasts over the hidden dimension.
header_mask = Lambda(lambda x: K.expand_dims(x, axis=-1))(inp_header_mask) 
# 2.2.3 Element-wise multiply zeroes out the padded header slots (mask == 0).
x_for_header = Multiply()([x_for_header, header_mask])
# Presumably masks the all-zero padded slots via Masking's default mask_value
# (0.0) so downstream layers ignore them -- confirm against the Keras version used.
x_for_header = Masking()(x_for_header)
# 2.2.4 Per-column softmax head for sel_agg.
p_sel_agg = Dense(num_sel_agg, activation='softmax', name='output_sel_agg')(x_for_header)
# 2.3 Per-column cond_op head, conditioned on the column vector and its sel_agg prediction.
x_for_cond_op = Concatenate(axis=-1)([x_for_header, p_sel_agg])
p_cond_op = Dense(num_cond_op, activation='softmax', name='output_cond_op')(x_for_cond_op)

# 3. Assemble the model (outputs ordered: cond_conn_op, sel_agg, cond_op).
model = Model(
    [inp_token_ids, inp_segment_ids, inp_header_ids, inp_header_mask],
    [p_cond_conn_op, p_sel_agg, p_cond_op]
)
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
learning_rate = 1e-5
# The labels are integer class ids (not one-hot), hence the sparse variant of
# categorical cross-entropy; the same loss is applied to all three outputs.
optimizer = RAdam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_token_ids (InputLayer)    (None, None)         0                                            
__________________________________________________________________________________________________
input_segment_ids (InputLayer)  (None, None)         0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, None, 768)    101677056   input_token_ids[0][0]            
                                                                 input_segment_ids[0][0]          
__________________________________________________________________________________________________
input_header_ids (InputLayer)   (None, None)         0                                            
__________________________________________________________________________________________________
input_header_mask (InputLayer)  (None, None)         0                                            
__________________________________________________________________________________________________
header_seq_gather (Lambda)      (None, None, 768)    0           model_2[1][0]                    
                                                                 input_header_ids[0][0]           
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None, 1)      0           input_header_mask[0][0]          
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, None, 768)    0           header_seq_gather[0][0]          
                                                                 lambda_2[0][0]                   
__________________________________________________________________________________________________
masking_1 (Masking)             (None, None, 768)    0           multiply_1[0][0]                 
__________________________________________________________________________________________________
output_sel_agg (Dense)          (None, None, 7)      5383        masking_1[0][0]                  
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 768)          0           model_2[1][0]                    
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, None, 775)    0           masking_1[0][0]                  
                                                                 output_sel_agg[0][0]             
__________________________________________________________________________________________________
output_cond_conn_op (Dense)     (None, 3)            2307        lambda_1[0][0]                   
__________________________________________________________________________________________________
output_cond_op (Dense)          (None, None, 5)      3880        concatenate_1[0][0]              
==================================================================================================
Total params: 101,688,626
Trainable params: 101,688,626
Non-trainable params: 0
__________________________________________________________________________________________________

小demo

# Demo: a function that takes one 2-element list can unpack it into two names.
def y(x):
    s, i = x
    print(s)
    print(i)
y([1,[2,3]])
1
[2, 3]
# Demo of tf.batch_gather: row i of the result takes tensor_a[i, idx[i]].
# With indices [[0],[1],[2]] this picks the diagonal (1, 5, 9); with all
# zeros it picks the first column (1, 4, 7).
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
[[1]
 [5]
 [9]]
[[1]
 [4]
 [7]]
def tensor_expand(tensor, i):
    """Evaluate ``tf.expand_dims(tensor, axis=i)`` and return it as a numpy array.

    The session is opened in a ``with`` block so its resources are released
    after every call.
    """
    tensor_out = tf.expand_dims(tensor, axis=i)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        return tensor_out.eval(session=sess)


# Show where the new size-1 axis is inserted for each axis argument.
for i in [-1, 0, 1, 2]:
    print('axis={}:\n{}\n'.format(i, tensor_expand(tensor_a, i)))

axis=-1:
[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]

 [[7]
  [8]
  [9]]]

axis=0:
[[[1 2 3]
  [4 5 6]
  [7 8 9]]]

axis=1:
[[[1 2 3]]

 [[4 5 6]]

 [[7 8 9]]]

axis=2:
[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]

 [[7]
  [8]
  [9]]]

训练模型

def outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, header_lens, label_encoder):
    """Decode a batch of model outputs into SQL dicts.

    Args:
        preds_cond_conn_op: (batch, num_cond_conn_op) class probabilities.
        preds_sel_agg: (batch, num_headers, num_sel_agg) class probabilities;
            the last class means "column not selected". Not modified.
        preds_cond_op: (batch, num_headers, num_cond_op) class probabilities.
        header_lens: number of real (unpadded) columns of each query.
        label_encoder: object whose ``decode(cond_conn_op, sel_agg, cond_op)``
            returns a dict with keys ``sel``, ``agg``, ``cond_conn_op``, ``conds``.

    Returns:
        One dict per query, restricted to columns below its header length.
    """
    # Softmax outputs -> most likely class id.
    conn_ops = np.argmax(preds_cond_conn_op, axis=-1)
    cond_ops = np.argmax(preds_cond_op, axis=-1)

    sqls = []
    for cond_conn_op, sel_agg, cond_op, header_len in zip(conn_ops, preds_sel_agg,
                                                          cond_ops, header_lens):
        # int(): the mask sum may come back as a numpy/float scalar.
        header_len = int(header_len)
        # Copy: a plain slice is a view, and the forcing below must not
        # overwrite the caller's prediction array.
        sel_agg = np.array(sel_agg[:header_len], copy=True)
        if header_len > 0:
            # Force at least one selected column: the highest probability
            # outside the last ("not selected") class is raised to 1, so it
            # wins the argmax below.
            sel_agg[sel_agg == sel_agg[:, :-1].max()] = 1
        sel_agg = np.argmax(sel_agg, axis=-1)

        sql = label_encoder.decode(cond_conn_op, sel_agg, cond_op)
        sql['conds'] = [cond for cond in sql['conds'] if cond[0] < header_len]
        kept = [(col_id, agg_op) for col_id, agg_op in zip(sql['sel'], sql['agg'])
                if col_id < header_len]
        sql['sel'] = [col_id for col_id, _ in kept]
        sql['agg'] = [agg_op for _, agg_op in kept]
        sqls.append(sql)
    return sqls

class EvaluateCallback(Callback):
    """Evaluate logical-form accuracy on a validation ``DataSequence`` after each epoch.

    Prints and writes into ``logs`` (so ``ModelCheckpoint`` can monitor
    ``val_tot_acc``):

    * conn_acc: cond_conn_op exactly right
    * agg_acc: set of (sel column, agg op) pairs exactly right
    * conds_acc: set of (cond column, cond op) pairs exactly right (values ignored)
    * conds_col_id_acc: set of condition columns exactly right
    * val_tot_acc: conn, agg and conds all right
    """

    def __init__(self, val_dataseq):
        super().__init__()
        self.val_dataseq = val_dataseq

    def _predict_sqls(self):
        """Run the model over every validation batch and decode the predictions."""
        dataseq = self.val_dataseq
        pred_sqls = []
        for batch_id in range(len(dataseq)):
            batch_data = dataseq[batch_id]
            header_lens = np.sum(batch_data['input_header_mask'], axis=-1)
            preds_cond_conn_op, preds_sel_agg, preds_cond_op = self.model.predict_on_batch(batch_data)
            pred_sqls += outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op,
                                         header_lens, dataseq.label_encoder)
        return pred_sqls

    def on_epoch_end(self, epoch, logs=None):
        pred_sqls = self._predict_sqls()

        conn_correct = 0
        agg_correct = 0
        conds_correct = 0
        conds_col_id_correct = 0
        all_correct = 0
        num_queries = len(self.val_dataseq.data)

        true_sqls = [query.sql for query in self.val_dataseq.data]
        for pred_sql, true_sql in zip(pred_sqls, true_sqls):
            conn_ok = pred_sql['cond_conn_op'] == true_sql.cond_conn_op
            agg_ok = (set(zip(pred_sql['sel'], pred_sql['agg']))
                      == set(zip(true_sql.sel, true_sql.agg)))
            # Conditions are compared as (column, operator) pairs; values are not predicted.
            conds_ok = ({(cond[0], cond[1]) for cond in pred_sql['conds']}
                        == {(cond[0], cond[1]) for cond in true_sql.conds})
            cols_ok = ({cond[0] for cond in pred_sql['conds']}
                       == {cond[0] for cond in true_sql.conds})

            conn_correct += int(conn_ok)
            agg_correct += int(agg_ok)
            conds_correct += int(conds_ok)
            conds_col_id_correct += int(cols_ok)
            all_correct += int(conn_ok and agg_ok and conds_ok)

        # Avoid ZeroDivisionError on an empty validation set.
        denom = max(num_queries, 1)
        conn_acc = conn_correct / denom
        agg_acc = agg_correct / denom
        conds_acc = conds_correct / denom
        conds_col_id_acc = conds_col_id_correct / denom
        total_acc = all_correct / denom

        print('conn_acc: {}'.format(conn_acc))
        print('agg_acc: {}'.format(agg_acc))
        print('conds_acc: {}'.format(conds_acc))
        print('conds_col_id_acc: {}'.format(conds_col_id_acc))
        print('total_acc: {}'.format(total_acc))

        if logs is not None:
            logs['val_tot_acc'] = total_acc
            logs['conn_acc'] = conn_acc
            logs['agg_acc'] = agg_acc
            logs['conds_acc'] = conds_acc
            logs['conds_col_id_acc'] = conds_col_id_acc
# batch_size = NUM_GPUS * 32
batch_size = 32
num_epochs = 30

# Settings shared by the training and validation sequences.
seq_kwargs = dict(
    tokenizer=query_tokenizer,
    label_encoder=label_encoder,
    shuffle_header=False,
    max_len=160,
    batch_size=batch_size,
)
train_dataseq = DataSequence(data=train_data, is_train=True, **seq_kwargs)
val_dataseq = DataSequence(data=val_data, is_train=False, shuffle=False, **seq_kwargs)

model_path = 'task1_best_model.h5'
# EvaluateCallback is listed first: it writes 'val_tot_acc' into the epoch logs,
# which ModelCheckpoint monitors to keep only the best weights.
callbacks = [
    EvaluateCallback(val_dataseq),
    ModelCheckpoint(
        filepath=model_path,
        monitor='val_tot_acc',
        mode='max',
        save_best_only=True,
        save_weights_only=True,
    ),
]
history = model.fit_generator(train_dataseq, epochs=num_epochs, callbacks=callbacks)

小demo

扫描二维码关注公众号,回复: 9924051 查看本文章
# Demo of the "force at least one selected column" trick: find the largest
# value excluding the last column, set every entry equal to it to 1, then take
# the row-wise argmax. With these integer values 1 is not the maximum, so the
# argmax is unchanged ([3 3]); the trick relies on inputs being probabilities <= 1.
s =np.array( [[1,2,3,4],[5,6,7,8]])
print(s)
x = s[:, :-1]
print(x)
x_1 = x.max()
print(x_1)
print(s == x_1)
s[s == s[:, :-1].max()] = 1
print(s)
v = np.argmax(s, axis=-1)
print(v)

[[1 2 3 4]
 [5 6 7 8]]
[[1 2 3]
 [5 6 7]]
7
[[False False False False]
 [False False  True False]]
[[1 2 3 4]
 [5 6 1 8]]
[3 3]

对测试集进行预测

model.load_weights(model_path)
# Test-set sequence: inputs only (is_train=False) and fixed order (shuffle=False),
# so predictions line up with test_data.
test_dataseq = DataSequence(
    data=test_data,
    tokenizer=query_tokenizer,
    label_encoder=label_encoder,
    is_train=False,
    shuffle_header=False,
    max_len=160,
    shuffle=False,
    batch_size=batch_size
)
pred_sqls = []

# Index batches explicitly (0 .. len-1) instead of relying on Sequence iteration.
for batch_id in tqdm(range(len(test_dataseq))):
    batch_data = test_dataseq[batch_id]
    header_lens = np.sum(batch_data['input_header_mask'], axis=-1)
    preds_cond_conn_op, preds_sel_agg, preds_cond_op = model.predict_on_batch(batch_data)
    # Decode with the test sequence's own label encoder.
    sqls = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op,
                           header_lens, test_dataseq.label_encoder)
    pred_sqls += sqls
100%|████████████████████████████████████████████████████████████████████████████████| 128/128 [15:59<00:00,  7.49s/it]
task1_output_file = 'task1_output.json'
# One JSON object per line. The explicit UTF-8 encoding matters because
# ensure_ascii=False may emit non-ASCII text, and the platform default
# encoding (e.g. GBK on Chinese Windows) is not guaranteed to handle it.
with open(task1_output_file, 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(sql, ensure_ascii=False) + '\n' for sql in pred_sqls)
发布了35 篇原创文章 · 获赞 3 · 访问量 2484

猜你喜欢

转载自blog.csdn.net/Smile_mingm/article/details/104824142