NL2SQL Learning (1): A BERT-based Baseline

This post is adapted from https://kexue.fm/archives/6771, with my own annotations explaining the code.

import json
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import codecs
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
from tqdm import tqdm
import jieba
import editdistance
import re
import numpy as np
import tensorflow as tf
import keras
import pandas as pd
print(tf.__version__)
print(keras.__version__)
1.13.1
2.2.4
'''
{
    "table_id": "a1b2c3d4", # 相应表格的id
    "question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?", # 自然语言问句
    "sql":{ # 真实SQL
        "sel": [7], # SQL选择的列 
        "agg": [0], # 选择的列相应的聚合函数, '0'代表无
        "cond_conn_op": 0, # 条件之间的关系
        "conds": [
            [1, 2, "世茂茂悦府"], # 条件列, 条件类型, 条件值,col_1 == "世茂茂悦府"
            [6, 0, "1"]
        ]
    }
}

# 其中条件运算符、聚合符、连接符分别如下
op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
conn_sql_dict = {0:"", 1:"and", 2:"or"}

'''
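# To make the label format concrete, here is a small helper of my own (not part of the
# original baseline; the name render_sql and the example headers are made up) that turns
# a label dict plus a header list into a readable SQL string using the three mappings above.
def render_sql(sql, headers, table_name='T'):
    op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
    agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
    conn_sql_dict = {0: "", 1: "and", 2: "or"}
    # SELECT clause: wrap each selected column in its aggregation function (if any)
    sel_parts = []
    for col, agg in zip(sql['sel'], sql['agg']):
        name = headers[col]
        sel_parts.append('%s(%s)' % (agg_sql_dict[agg], name) if agg else name)
    # WHERE clause: join the conditions with the predicted connector
    conn = ' %s ' % conn_sql_dict[sql['cond_conn_op']] if sql['cond_conn_op'] else ' '
    where = conn.join('%s %s "%s"' % (headers[c], op_sql_dict[o], v) for c, o, v in sql['conds'])
    return 'SELECT %s FROM %s WHERE %s' % (', '.join(sel_parts), table_name, where)

# e.g. with made-up headers for the sample above:
# render_sql({'sel': [7], 'agg': [0], 'cond_conn_op': 0,
#             'conds': [[1, 2, '世茂茂悦府'], [6, 0, '1']]},
#            ['h0', '小区名称', 'h2', 'h3', 'h4', 'h5', '容积率', '套均面积'])
# -> 'SELECT 套均面积 FROM T WHERE 小区名称 == "世茂茂悦府" 容积率 > "1"'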

maxlen = 160
num_agg = 7 # agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM", 6:"not selected"}
num_op = 5 # {0:">", 1:"<", 2:"==", 3:"!=", 4:"not part of any condition"}
num_cond_conn_op = 3 # conn_sql_dict = {0:"", 1:"and", 2:"or"}
learning_rate = 5e-5
min_learning_rate = 1e-5


config_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_config.json'
checkpoint_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_model.ckpt'
dict_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\vocab.txt'

def read_data(data_file, table_file):
    data, tables = [], {}
    with open(data_file,encoding='UTF-8') as f:
        for l in f:
            data.append(json.loads(l))
    with open(table_file,encoding='UTF-8') as f:
        for l in f:
            l = json.loads(l)
            # Inspecting the table file shows each line has the fields rows, name, title, header, common, id, types:
            # rows is the table itself with the concrete cell values; name is the table's name; title is unclear;
            # header holds the column names; common is unclear; id is the table's id;
            # types gives the type of each column's values (text, real, etc.)

            # Build a new dict for every table:
            # the original 'header' field becomes 'headers'
            # header2id records an index for every header: {name: index}
            # content starts as an empty dict
            # all_values starts as an empty set()
            # rows: the row list is stored as a numpy array
            d = {}
            d['headers'] = l['header']
            d['header2id'] = {j: i for i, j in enumerate(d['headers'])}
            d['content'] = {}
            d['all_values'] = set()
            rows = np.array(l['rows'])
            
            # Fill the content dict: {column name: set of values in that column}, duplicates removed
            for i, h in enumerate(d['headers']):
                d['content'][h] = set(rows[:, i])
                # Collect every value in the table (deduplicated): set.update() merges into the existing set
                d['all_values'].update(d['content'][h])
                
            # hasattr() checks whether an object has the given attribute;
            # keep only values that have a __len__ (i.e. string-like values), dropping empty/scalar entries
            d['all_values'] = set([i for i in d['all_values'] if hasattr(i, '__len__')])
            # index the per-table dict by its table id: {id: d}
            tables[l['id']] = d
    
    return data, tables
train_data, train_tables = read_data('E:/zym_test/test/nlp/data/train/train.json','E:/zym_test/test/nlp/data/train/train.tables.json')
valid_data, valid_tables = read_data('E:/zym_test/test/nlp/data/val/val.json','E:/zym_test/test/nlp/data/val/val.tables.json')
test_data, test_tables = read_data('E:/zym_test/test/nlp/data/test/test.json','E:/zym_test/test/nlp/data/test/test.tables.json')
train_data[0:4]
[{'table_id': '4d29d0513aaa11e9b911f40f24344a08',
  'question': '二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀',
  'sql': {'agg': [5],
   'cond_conn_op': 2,
   'sel': [2],
   'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
 {'table_id': '4d29d0513aaa11e9b911f40f24344a08',
  'question': '你好,你知道今年第四周密室逃生,还有那部大黄蜂它们票房总的占比吗',
  'sql': {'agg': [5],
   'cond_conn_op': 2,
   'sel': [2],
   'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
 {'table_id': '4d29d0513aaa11e9b911f40f24344a08',
  'question': '我想你帮我查一下第四周大黄蜂,还有密室逃生这两部电影票房的占比加起来会是多少来着',
  'sql': {'agg': [5],
   'cond_conn_op': 2,
   'sel': [2],
   'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}},
 {'table_id': '4d25e6403aaa11e9bdbbf40f24344a08',
  'question': '有几家传媒公司16年为了融资收购其他资产而进行定增的呀',
  'sql': {'agg': [4],
   'cond_conn_op': 1,
   'sel': [1],
   'conds': [[6, 2, '2016'], [7, 2, '融资收购其他资产']]}}]
train_tables[ '4d29d0513aaa11e9b911f40f24344a08' ]
{'headers': ['影片名称', '周票房(万)', '票房占比(%)', '场均人次'],
 'header2id': {'影片名称': 0, '周票房(万)': 1, '票房占比(%)': 2, '场均人次': 3},
 'content': {'影片名称': {'“大”人物',
   '一条狗的回家路',
   '大黄蜂',
   '家和万事惊',
   '密室逃生',
   '掠食城市',
   '死侍2:我爱我家',
   '海王',
   '白蛇:缘起',
   '钢铁飞龙之奥特曼崛起'},
  '周票房(万)': {'10503.8',
   '10637.3',
   '3322.9',
   '356.6',
   '360.0',
   '500.3',
   '5841.4',
   '595.5',
   '635.2',
   '6426.6'},
  '票房占比(%)': {'0.9',
   '1.2',
   '1.4',
   '1.5',
   '14.2',
   '15.6',
   '25.4',
   '25.8',
   '8.1'},
  '场均人次': {'25.0', '3.0', '4.0', '5.0', '6.0', '7.0'}},
 'all_values': {'0.9',
  '1.2',
  '1.4',
  '1.5',
  '10503.8',
  '10637.3',
  '14.2',
  '15.6',
  '25.0',
  '25.4',
  '25.8',
  '3.0',
  '3322.9',
  '356.6',
  '360.0',
  '4.0',
  '5.0',
  '500.3',
  '5841.4',
  '595.5',
  '6.0',
  '635.2',
  '6426.6',
  '7.0',
  '8.1',
  '“大”人物',
  '一条狗的回家路',
  '大黄蜂',
  '家和万事惊',
  '密室逃生',
  '掠食城市',
  '死侍2:我爱我家',
  '海王',
  '白蛇:缘起',
  '钢铁飞龙之奥特曼崛起'}}
train_tables['4d25e6403aaa11e9bdbbf40f24344a08']
{'headers': ['证券代码',
  '证券简称',
  '最新收盘价',
  '定增价除权后至今价格',
  '增发价格',
  '倒挂率',
  '定增年度',
  '增发目的'],
 'header2id': {'证券代码': 0,
  '证券简称': 1,
  '最新收盘价': 2,
  '定增价除权后至今价格': 3,
  '增发价格': 4,
  '倒挂率': 5,
  '定增年度': 6,
  '增发目的': 7},
 'content': {'证券代码': {'300148.SZ', '300182.SZ', '300269.SZ'},
  '证券简称': {'天舟文化', '捷成股份', '联建光电'},
  '最新收盘价': {'4.09', '4.69', '5.48'},
  '定增价除权后至今价格': {'11.16', '11.29', '12.48', '21.88', '23.07', '9.91'},
  '增发价格': {'14.78', '15.09', '16.34', '16.988', '22.09', '23.3004'},
  '倒挂率': {'23.75', '25.05', '36.65', '37.58', '41.26', '41.54'},
  '定增年度': {'2016.0'},
  '增发目的': {'融资收购其他资产', '配套融资'}},
 'all_values': {'11.16',
  '11.29',
  '12.48',
  '14.78',
  '15.09',
  '16.34',
  '16.988',
  '2016.0',
  '21.88',
  '22.09',
  '23.07',
  '23.3004',
  '23.75',
  '25.05',
  '300148.SZ',
  '300182.SZ',
  '300269.SZ',
  '36.65',
  '37.58',
  '4.09',
  '4.69',
  '41.26',
  '41.54',
  '5.48',
  '9.91',
  '天舟文化',
  '捷成股份',
  '联建光电',
  '融资收购其他资产',
  '配套融资'}}
# Encode each character
# Read the vocab file and build a dict mapping every token to its index
token_dict = {}

with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)
token_dict
{'[PAD]': 0,
 '[unused1]': 1,
 '[unused2]': 2,
 '[unused3]': 3,
 '[unused4]': 4,
 '[unused5]': 5,
 '[unused6]': 6,
 '[unused7]': 7,
 '[unused8]': 8,
 '[unused9]': 9,
 '[unused10]': 10,
 '[unused11]': 11,
 '[unused12]': 12,
 '[unused13]': 13,
 '[unused14]': 14,
 '[unused15]': 15,
 '[unused16]': 16,
 '[unused17]': 17,
 '[unused18]': 18,
 '[unused19]': 19,
 '[unused20]': 20,
 '[unused21]': 21,
 '[unused22]': 22,
 '[unused23]': 23,
 '[unused24]': 24,
 '[unused25]': 25,
 '[unused26]': 26,
 '[unused27]': 27,
 '[unused28]': 28,
 '[unused29]': 29,
 '[unused30]': 30,
 '[unused31]': 31,
 '[unused32]': 32,
 '[unused33]': 33,
 '[unused34]': 34,
 '[unused35]': 35,
 '[unused36]': 36,
 '[unused37]': 37,
 '[unused38]': 38,
 '[unused39]': 39,
 '[unused40]': 40,
 '[unused41]': 41,
 '[unused42]': 42,
 '[unused43]': 43,
 '[unused44]': 44,
 '[unused45]': 45,
 '[unused46]': 46,
 '[unused47]': 47,
 '[unused48]': 48,
 '[unused49]': 49,
 '[unused50]': 50,
 '[unused51]': 51,
 '[unused52]': 52,
 '[unused53]': 53,
 '[unused54]': 54,
 '[unused55]': 55,
 '[unused56]': 56,
 '[unused57]': 57,
 '[unused58]': 58,
 '[unused59]': 59,
 '[unused60]': 60,
 '[unused61]': 61,
 '[unused62]': 62,
 '[unused63]': 63,
 '[unused64]': 64,
 '[unused65]': 65,
 '[unused66]': 66,
 '[unused67]': 67,
 '[unused68]': 68,
 '[unused69]': 69,
 '[unused70]': 70,
 '[unused71]': 71,
 '[unused72]': 72,
 '[unused73]': 73,
 '[unused74]': 74,
 '[unused75]': 75,
 '[unused76]': 76,
 '[unused77]': 77,
 '[unused78]': 78,
 '[unused79]': 79,
 '[unused80]': 80,
 '[unused81]': 81,
 '[unused82]': 82,
 '[unused83]': 83,
 '[unused84]': 84,
 '[unused85]': 85,
 '[unused86]': 86,
 '[unused87]': 87,
 '[unused88]': 88,
 '[unused89]': 89,
 '[unused90]': 90,
 '[unused91]': 91,
 '[unused92]': 92,
 '[unused93]': 93,
 '[unused94]': 94,
 '[unused95]': 95,
 '[unused96]': 96,
 '[unused97]': 97,
 '[unused98]': 98,
 '[unused99]': 99,
 '[UNK]': 100,
 '[CLS]': 101,
 '[SEP]': 102,
 '[MASK]': 103,
 '<S>': 104,
 '<T>': 105,
 '!': 106,
 '"': 107,
 '#': 108,
 '$': 109,
 '%': 110,
 '&': 111,
 "'": 112,
 '(': 113,
 ')': 114,
 '*': 115,
 '+': 116,
 ',': 117,
 '-': 118,
 '.': 119,
 '/': 120,
 '0': 121,
 '1': 122,
 '2': 123,
 '3': 124,
 '4': 125,
 '5': 126,
 '6': 127,
 '7': 128,
 '8': 129,
 '9': 130,
 ':': 131,
 ';': 132,
 '<': 133,
 '=': 134,
 '>': 135,
 '?': 136,
 '@': 137,
 '[': 138,
 '\\': 139,
 ']': 140,
 '^': 141,
 '_': 142,
 'a': 143,
 'b': 144,
 'c': 145,
 'd': 146,
 'e': 147,
 'f': 148,
 'g': 149,
 'h': 150,
 'i': 151,
 'j': 152,
 'k': 153,
 'l': 154,
 'm': 155,
 'n': 156,
 'o': 157,
 'p': 158,
 'q': 159,
 'r': 160,
 's': 161,
 't': 162,
 'u': 163,
 'v': 164,
 'w': 165,
 'x': 166,
 'y': 167,
 'z': 168,
 '{': 169,
 '|': 170,
 '}': 171,
 '~': 172,
 '£': 173,
 '¤': 174,
 '¥': 175,
 '§': 176,
 '©': 177,
 '«': 178,
 '®': 179,
 '°': 180,
 '±': 181,
 '²': 182,
 '³': 183,
 'µ': 184,
 '·': 185,
 '¹': 186,
 'º': 187,
 '»': 188,
 '¼': 189,
 '×': 190,
 'ß': 191,
 'æ': 192,
 '÷': 193,
 'ø': 194,
 'đ': 195,
 'ŋ': 196,
 'ɔ': 197,
 'ə': 198,
 'ɡ': 199,
 'ʰ': 200,
 'ˇ': 201,
 'ˈ': 202,
 'ˊ': 203,
 'ˋ': 204,
 'ˍ': 205,
 'ː': 206,
 '˙': 207,
 '˚': 208,
 'ˢ': 209,
 'α': 210,
 'β': 211,
 'γ': 212,
 'δ': 213,
 'ε': 214,
 'η': 215,
 'θ': 216,
 'ι': 217,
 'κ': 218,
 'λ': 219,
 'μ': 220,
 'ν': 221,
 'ο': 222,
 'π': 223,
 'ρ': 224,
 'ς': 225,
 'σ': 226,
 'τ': 227,
 'υ': 228,
 'φ': 229,
 'χ': 230,
 'ψ': 231,
 'ω': 232,
 'а': 233,
 'б': 234,
 'в': 235,
 'г': 236,
 'д': 237,
 'е': 238,
 'ж': 239,
 'з': 240,
 'и': 241,
 'к': 242,
 'л': 243,
 'м': 244,
 'н': 245,
 'о': 246,
 'п': 247,
 'р': 248,
 'с': 249,
 'т': 250,
 'у': 251,
 'ф': 252,
 'х': 253,
 'ц': 254,
 'ч': 255,
 'ш': 256,
 'ы': 257,
 'ь': 258,
 'я': 259,
 'і': 260,
 'ا': 261,
 'ب': 262,
 'ة': 263,
 'ت': 264,
 'د': 265,
 'ر': 266,
 'س': 267,
 'ع': 268,
 'ل': 269,
 'م': 270,
 'ن': 271,
 'ه': 272,
 'و': 273,
 'ي': 274,
 '۩': 275,
 'ก': 276,
 'ง': 277,
 'น': 278,
 'ม': 279,
 'ย': 280,
 'ร': 281,
 'อ': 282,
 'า': 283,
 'เ': 284,
 '๑': 285,
 '་': 286,
 'ღ': 287,
 'ᄀ': 288,
 'ᄁ': 289,
 'ᄂ': 290,
 'ᄃ': 291,
 'ᄅ': 292,
 'ᄆ': 293,
 'ᄇ': 294,
 'ᄈ': 295,
 'ᄉ': 296,
 'ᄋ': 297,
 'ᄌ': 298,
 'ᄎ': 299,
 'ᄏ': 300,
 'ᄐ': 301,
 'ᄑ': 302,
 'ᄒ': 303,
 'ᅡ': 304,
 'ᅢ': 305,
 'ᅣ': 306,
 'ᅥ': 307,
 'ᅦ': 308,
 'ᅧ': 309,
 'ᅨ': 310,
 'ᅩ': 311,
 'ᅪ': 312,
 'ᅬ': 313,
 'ᅭ': 314,
 'ᅮ': 315,
 'ᅯ': 316,
 'ᅲ': 317,
 'ᅳ': 318,
 'ᅴ': 319,
 'ᅵ': 320,
 'ᆨ': 321,
 'ᆫ': 322,
 'ᆯ': 323,
 'ᆷ': 324,
 'ᆸ': 325,
 'ᆺ': 326,
 'ᆻ': 327,
 'ᆼ': 328,
 'ᗜ': 329,
 'ᵃ': 330,
 'ᵉ': 331,
 'ᵍ': 332,
 'ᵏ': 333,
 'ᵐ': 334,
 'ᵒ': 335,
 'ᵘ': 336,
 '‖': 337,
 '„': 338,
 '†': 339,
 '•': 340,
 '‥': 341,
 '‧': 342,
 '': 13503,
 '‰': 344,
 '′': 345,
 '″': 346,
 '‹': 347,
 '›': 348,
 '※': 349,
 '‿': 350,
 '⁄': 351,
 'ⁱ': 352,
 '⁺': 353,
 'ⁿ': 354,
 '₁': 355,
 '₂': 356,
 '₃': 357,
 '₄': 358,
 '€': 359,
 '℃': 360,
 '№': 361,
 '™': 362,
 'ⅰ': 363,
 'ⅱ': 364,
 'ⅲ': 365,
 'ⅳ': 366,
 'ⅴ': 367,
 '←': 368,
 '↑': 369,
 '→': 370,
 '↓': 371,
 '↔': 372,
 '↗': 373,
 '↘': 374,
 '⇒': 375,
 '∀': 376,
 '−': 377,
 '∕': 378,
 '∙': 379,
 '√': 380,
 '∞': 381,
 '∟': 382,
 '∠': 383,
 '∣': 384,
 '∥': 385,
 '∩': 386,
 '∮': 387,
 '∶': 388,
 '∼': 389,
 '∽': 390,
 '≈': 391,
 '≒': 392,
 '≡': 393,
 '≤': 394,
 '≥': 395,
 '≦': 396,
 '≧': 397,
 '≪': 398,
 '≫': 399,
 '⊙': 400,
 '⋅': 401,
 '⋈': 402,
 '⋯': 403,
 '⌒': 404,
 '①': 405,
 '②': 406,
 '③': 407,
 '④': 408,
 '⑤': 409,
 '⑥': 410,
 '⑦': 411,
 '⑧': 412,
 '⑨': 413,
 '⑩': 414,
 '⑴': 415,
 '⑵': 416,
 '⑶': 417,
 '⑷': 418,
 '⑸': 419,
 '⒈': 420,
 '⒉': 421,
 '⒊': 422,
 '⒋': 423,
 'ⓒ': 424,
 'ⓔ': 425,
 'ⓘ': 426,
 '─': 427,
 '━': 428,
 '│': 429,
 '┃': 430,
 '┅': 431,
 '┆': 432,
 '┊': 433,
 '┌': 434,
 '└': 435,
 '├': 436,
 '┣': 437,
 '═': 438,
 '║': 439,
 '╚': 440,
 '╞': 441,
 '╠': 442,
 '╭': 443,
 '╮': 444,
 '╯': 445,
 '╰': 446,
 '╱': 447,
 '╳': 448,
 '▂': 449,
 '▃': 450,
 '▅': 451,
 '▇': 452,
 '█': 453,
 '▉': 454,
 '▋': 455,
 '▌': 456,
 '▍': 457,
 '▎': 458,
 '■': 459,
 '□': 460,
 '▪': 461,
 '▫': 462,
 '▬': 463,
 '▲': 464,
 '△': 465,
 '▶': 466,
 '►': 467,
 '▼': 468,
 '▽': 469,
 '◆': 470,
 '◇': 471,
 '○': 472,
 '◎': 473,
 '●': 474,
 '◕': 475,
 '◠': 476,
 '◢': 477,
 '◤': 478,
 '☀': 479,
 '★': 480,
 '☆': 481,
 '☕': 482,
 '☞': 483,
 '☺': 484,
 '☼': 485,
 '♀': 486,
 '♂': 487,
 '♠': 488,
 '♡': 489,
 '♣': 490,
 '♥': 491,
 '♦': 492,
 '♪': 493,
 '♫': 494,
 '♬': 495,
 '✈': 496,
 '✔': 497,
 '✕': 498,
 '✖': 499,
 '✦': 500,
 '✨': 501,
 '✪': 502,
 '✰': 503,
 '✿': 504,
 '❀': 505,
 '❤': 506,
 '➜': 507,
 '➤': 508,
 '⦿': 509,
 '、': 510,
 '。': 511,
 '〃': 512,
 '々': 513,
 '〇': 514,
 '〈': 515,
 '〉': 516,
 '《': 517,
 '》': 518,
 '「': 519,
 '」': 520,
 '『': 521,
 '』': 522,
 '【': 523,
 '】': 524,
 '〓': 525,
 '〔': 526,
 '〕': 527,
 '〖': 528,
 '〗': 529,
 '〜': 530,
 '〝': 531,
 '〞': 532,
 'ぁ': 533,
 'あ': 534,
 'ぃ': 535,
 'い': 536,
 'う': 537,
 'ぇ': 538,
 'え': 539,
 'お': 540,
 'か': 541,
 'き': 542,
 'く': 543,
 'け': 544,
 'こ': 545,
 'さ': 546,
 'し': 547,
 'す': 548,
 'せ': 549,
 'そ': 550,
 'た': 551,
 'ち': 552,
 'っ': 553,
 'つ': 554,
 'て': 555,
 'と': 556,
 'な': 557,
 'に': 558,
 'ぬ': 559,
 'ね': 560,
 'の': 561,
 'は': 562,
 'ひ': 563,
 'ふ': 564,
 'へ': 565,
 'ほ': 566,
 'ま': 567,
 'み': 568,
 'む': 569,
 'め': 570,
 'も': 571,
 'ゃ': 572,
 'や': 573,
 'ゅ': 574,
 'ゆ': 575,
 'ょ': 576,
 'よ': 577,
 'ら': 578,
 'り': 579,
 'る': 580,
 'れ': 581,
 'ろ': 582,
 'わ': 583,
 'を': 584,
 'ん': 585,
 '゜': 586,
 'ゝ': 587,
 'ァ': 588,
 'ア': 589,
 'ィ': 590,
 'イ': 591,
 'ゥ': 592,
 'ウ': 593,
 'ェ': 594,
 'エ': 595,
 'ォ': 596,
 'オ': 597,
 'カ': 598,
 'キ': 599,
 'ク': 600,
 'ケ': 601,
 'コ': 602,
 'サ': 603,
 'シ': 604,
 'ス': 605,
 'セ': 606,
 'ソ': 607,
 'タ': 608,
 'チ': 609,
 'ッ': 610,
 'ツ': 611,
 'テ': 612,
 'ト': 613,
 'ナ': 614,
 'ニ': 615,
 'ヌ': 616,
 'ネ': 617,
 'ノ': 618,
 'ハ': 619,
 'ヒ': 620,
 'フ': 621,
 'ヘ': 622,
 'ホ': 623,
 'マ': 624,
 'ミ': 625,
 'ム': 626,
 'メ': 627,
 'モ': 628,
 'ャ': 629,
 'ヤ': 630,
 'ュ': 631,
 'ユ': 632,
 'ョ': 633,
 'ヨ': 634,
 'ラ': 635,
 'リ': 636,
 'ル': 637,
 'レ': 638,
 'ロ': 639,
 'ワ': 640,
 'ヲ': 641,
 'ン': 642,
 'ヶ': 643,
 '・': 644,
 'ー': 645,
 'ヽ': 646,
 'ㄅ': 647,
 'ㄆ': 648,
 'ㄇ': 649,
 'ㄉ': 650,
 'ㄋ': 651,
 'ㄌ': 652,
 'ㄍ': 653,
 'ㄎ': 654,
 'ㄏ': 655,
 'ㄒ': 656,
 'ㄚ': 657,
 'ㄛ': 658,
 'ㄞ': 659,
 'ㄟ': 660,
 'ㄢ': 661,
 'ㄤ': 662,
 'ㄥ': 663,
 'ㄧ': 664,
 'ㄨ': 665,
 'ㆍ': 666,
 '㈦': 667,
 '㊣': 668,
 '㎡': 669,
 '㗎': 670,
 '一': 671,
 '丁': 672,
 '七': 673,
 '万': 674,
 '丈': 675,
 '三': 676,
 '上': 677,
 '下': 678,
 '不': 679,
 '与': 680,
 '丐': 681,
 '丑': 682,
 '专': 683,
 '且': 684,
 '丕': 685,
 '世': 686,
 '丘': 687,
 '丙': 688,
 '业': 689,
 '丛': 690,
 '东': 691,
 '丝': 692,
 '丞': 693,
 '丟': 694,
 '両': 695,
 '丢': 696,
 '两': 697,
 '严': 698,
 '並': 699,
 '丧': 700,
 '丨': 701,
 '个': 702,
 '丫': 703,
 '中': 704,
 '丰': 705,
 '串': 706,
 '临': 707,
 '丶': 708,
 '丸': 709,
 '丹': 710,
 '为': 711,
 '主': 712,
 '丼': 713,
 '丽': 714,
 '举': 715,
 '丿': 716,
 '乂': 717,
 '乃': 718,
 '久': 719,
 '么': 720,
 '义': 721,
 '之': 722,
 '乌': 723,
 '乍': 724,
 '乎': 725,
 '乏': 726,
 '乐': 727,
 '乒': 728,
 '乓': 729,
 '乔': 730,
 '乖': 731,
 '乗': 732,
 '乘': 733,
 '乙': 734,
 '乜': 735,
 '九': 736,
 '乞': 737,
 '也': 738,
 '习': 739,
 '乡': 740,
 '书': 741,
 '乩': 742,
 '买': 743,
 '乱': 744,
 '乳': 745,
 '乾': 746,
 '亀': 747,
 '亂': 748,
 '了': 749,
 '予': 750,
 '争': 751,
 '事': 752,
 '二': 753,
 '于': 754,
 '亏': 755,
 '云': 756,
 '互': 757,
 '五': 758,
 '井': 759,
 '亘': 760,
 '亙': 761,
 '亚': 762,
 '些': 763,
 '亜': 764,
 '亞': 765,
 '亟': 766,
 '亡': 767,
 '亢': 768,
 '交': 769,
 '亥': 770,
 '亦': 771,
 '产': 772,
 '亨': 773,
 '亩': 774,
 '享': 775,
 '京': 776,
 '亭': 777,
 '亮': 778,
 '亲': 779,
 '亳': 780,
 '亵': 781,
 '人': 782,
 '亿': 783,
 '什': 784,
 '仁': 785,
 '仃': 786,
 '仄': 787,
 '仅': 788,
 '仆': 789,
 '仇': 790,
 '今': 791,
 '介': 792,
 '仍': 793,
 '从': 794,
 '仏': 795,
 '仑': 796,
 '仓': 797,
 '仔': 798,
 '仕': 799,
 '他': 800,
 '仗': 801,
 '付': 802,
 '仙': 803,
 '仝': 804,
 '仞': 805,
 '仟': 806,
 '代': 807,
 '令': 808,
 '以': 809,
 '仨': 810,
 '仪': 811,
 '们': 812,
 '仮': 813,
 '仰': 814,
 '仲': 815,
 '件': 816,
 '价': 817,
 '任': 818,
 '份': 819,
 '仿': 820,
 '企': 821,
 '伉': 822,
 '伊': 823,
 '伍': 824,
 '伎': 825,
 '伏': 826,
 '伐': 827,
 '休': 828,
 '伕': 829,
 '众': 830,
 '优': 831,
 '伙': 832,
 '会': 833,
 '伝': 834,
 '伞': 835,
 '伟': 836,
 '传': 837,
 '伢': 838,
 '伤': 839,
 '伦': 840,
 '伪': 841,
 '伫': 842,
 '伯': 843,
 '估': 844,
 '伴': 845,
 '伶': 846,
 '伸': 847,
 '伺': 848,
 '似': 849,
 '伽': 850,
 '佃': 851,
 '但': 852,
 '佇': 853,
 '佈': 854,
 '位': 855,
 '低': 856,
 '住': 857,
 '佐': 858,
 '佑': 859,
 '体': 860,
 '佔': 861,
 '何': 862,
 '佗': 863,
 '佘': 864,
 '余': 865,
 '佚': 866,
 '佛': 867,
 '作': 868,
 '佝': 869,
 '佞': 870,
 '佟': 871,
 '你': 872,
 '佢': 873,
 '佣': 874,
 '佤': 875,
 '佥': 876,
 '佩': 877,
 '佬': 878,
 '佯': 879,
 '佰': 880,
 '佳': 881,
 '併': 882,
 '佶': 883,
 '佻': 884,
 '佼': 885,
 '使': 886,
 '侃': 887,
 '侄': 888,
 '來': 889,
 '侈': 890,
 '例': 891,
 '侍': 892,
 '侏': 893,
 '侑': 894,
 '侖': 895,
 '侗': 896,
 '供': 897,
 '依': 898,
 '侠': 899,
 '価': 900,
 '侣': 901,
 '侥': 902,
 '侦': 903,
 '侧': 904,
 '侨': 905,
 '侬': 906,
 '侮': 907,
 '侯': 908,
 '侵': 909,
 '侶': 910,
 '侷': 911,
 '便': 912,
 '係': 913,
 '促': 914,
 '俄': 915,
 '俊': 916,
 '俎': 917,
 '俏': 918,
 '俐': 919,
 '俑': 920,
 '俗': 921,
 '俘': 922,
 '俚': 923,
 '保': 924,
 '俞': 925,
 '俟': 926,
 '俠': 927,
 '信': 928,
 '俨': 929,
 '俩': 930,
 '俪': 931,
 '俬': 932,
 '俭': 933,
 '修': 934,
 '俯': 935,
 '俱': 936,
 '俳': 937,
 '俸': 938,
 '俺': 939,
 '俾': 940,
 '倆': 941,
 '倉': 942,
 '個': 943,
 '倌': 944,
 '倍': 945,
 '倏': 946,
 '們': 947,
 '倒': 948,
 '倔': 949,
 '倖': 950,
 '倘': 951,
 '候': 952,
 '倚': 953,
 '倜': 954,
 '借': 955,
 '倡': 956,
 '値': 957,
 '倦': 958,
 '倩': 959,
 '倪': 960,
 '倫': 961,
 '倬': 962,
 '倭': 963,
 '倶': 964,
 '债': 965,
 '值': 966,
 '倾': 967,
 '偃': 968,
 '假': 969,
 '偈': 970,
 '偉': 971,
 '偌': 972,
 '偎': 973,
 '偏': 974,
 '偕': 975,
 '做': 976,
 '停': 977,
 '健': 978,
 '側': 979,
 '偵': 980,
 '偶': 981,
 '偷': 982,
 '偻': 983,
 '偽': 984,
 '偿': 985,
 '傀': 986,
 '傅': 987,
 '傍': 988,
 '傑': 989,
 '傘': 990,
 '備': 991,
 '傚': 992,
 '傢': 993,
 '傣': 994,
 '傥': 995,
 '储': 996,
 '傩': 997,
 '催': 998,
 '傭': 999,
 ...}
# Override the Tokenizer so that the tokenized text keeps the same length as the original text.
# The built-in _tokenize silently drops spaces and merges some characters into a single token,
# so the token list no longer lines up with the original string, which makes sequence labelling painful.

# Subclass Tokenizer
class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]') # represent whitespace with the untrained [unused1] token
            else:
                R.append('[UNK]') # all remaining characters map to [UNK]
        return R

# Build the tokenizer from the vocab
tokenizer = OurTokenizer(token_dict)
tokenizer
<__main__.OurTokenizer at 0x2430e637908>
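# A quick sanity check of my own (the demo string is arbitrary): with OurTokenizer every
# character, including spaces, maps to exactly one token, so the token list always has
# length len(text) + 2 once [CLS] and [SEP] are added — exactly what sequence labelling needs.
demo = '二零一九年 第四周'
demo_tokens = tokenizer.tokenize(demo)
print(demo_tokens)                       # the space shows up as '[unused1]'
print(len(demo_tokens), len(demo) + 2)   # both should be 11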
# Sequence padding
# Pad every sequence in a batch to the same length
def seq_padding(X, padding=0, maxlen=None):
    if maxlen is None:
        L = [len(x) for x in X]  # length of each sequence
        ML = max(L)  # length of the longest sequence
    else:
        ML = maxlen
    return np.array([
        np.concatenate([x[:ML], [padding] * (ML - len(x))]) if len(x[:ML]) < ML else x for x in X
    ])              # np.concatenate pads x with `padding` up to length ML; if x is already at least ML long it is left unchanged

# for x in X:
#     if len(x[:ML]) < ML:
#         np.concatenate([x[:ML], [padding] * (ML - len(x))])
#     else:
#         x
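# A small usage example of my own, showing what seq_padding does to a ragged batch:
# every sequence is padded with 0 up to the length of the longest one.
demo_batch = [[101, 753, 102], [101, 753, 7439, 671, 102], [101, 102]]
print(seq_padding(demo_batch))
# roughly:
# [[ 101  753  102    0    0]
#  [ 101  753 7439  671  102]
#  [ 101  102    0    0    0]]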
def most_similar(s, slist):
    """从词表中找最相近的词(当无法全匹配的时候)
    """
    if len(slist) == 0:
        return s
    scores = [editdistance.eval(s, t) for t in slist]#最小编辑距离算法
    return slist[np.argmin(scores)]
def most_similar_2(w, s):
    """从句子s中找与w最相近的片段,
    借助分词工具和ngram的方式尽量精确地确定边界。
    """
    sw = jieba.lcut(s)
    sl = list(sw)
    sl.extend([''.join(i) for i in zip(sw, sw[1:])])
    sl.extend([''.join(i) for i in zip(sw, sw[1:], sw[2:])])
    return most_similar(w, sl)
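# Two quick illustrations of my own of the fuzzy matching (the inputs are made up):
# most_similar picks the candidate with the smallest edit distance,
print(most_similar('大黄蜂', ['大黄蜂', '密室逃生', '海王']))   # -> '大黄蜂' (exact match, distance 0)
# while most_similar_2 searches the word / 2-gram / 3-gram spans of a question,
# which is how condition values that are phrased slightly differently get located.
print(most_similar_2('世茂茂悦府', '世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?'))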
d=train_data[0]
# Encode "二零一九年" and "我是傻子" as a sentence pair against the vocab; [CLS] and [SEP] are added automatically
# x1 holds the token ids, x2 holds the segment ids indicating which sentence each position belongs to
x1, x2 = tokenizer.encode('二零一九年',"我是傻子")
print(x1)
print(x2)
print(len(x1)) # besides token embeddings, BERT's input also needs position and segment information
print(len(x2))
[101, 753, 7439, 671, 736, 2399, 102, 2769, 3221, 1004, 2094, 102]
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
12
12
# Data generator
class data_generator:
    def __init__(self, data, tables, batch_size=32):
        # set batch_size and the number of steps per epoch
        self.data = data
        self.tables = tables
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1
    def __len__(self):
        return self.steps
    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
            # iterate over the samples: each d is a dict with the question, its table_id, and the sql fields 'agg', 'cond_conn_op', 'sel', 'conds'
            for i in idxs:
                d = self.data[i]
                # look up the column names of the table this question refers to
                t = self.tables[d['table_id']]['headers']
                # encode the question
                x1, x2 = tokenizer.encode(d['question'])
                # question mask, the same length as the encoded question: [0, 1, 1, ..., 1, 0]
                xm = [0] + [1] * len(d['question']) + [0]
                
                h = []
                for j in t:
                    # encode the column name
                    _x1, _x2 = tokenizer.encode(j)
                    # record where this header's [CLS] token ends up in the concatenated sequence
                    h.append(len(x1))
                    # extend() appends every element of another sequence to the list,
                    # i.e. the header encoding is appended after the question encoding
                    x1.extend(_x1)
                    x2.extend(_x2)
                # header mask: one 1 per header
                hm = [1] * len(h)
                
                sel = []
                for j in range(len(h)):
                    # if column j is one of the selected columns, look up its position in 'sel'
                    # and take the aggregation function recorded at the same position in 'agg'
                    if j in d['sql']['sel']:
                        j = d['sql']['sel'].index(j)
                        sel.append(d['sql']['agg'][j])
                    else:
                        sel.append(num_agg - 1) # columns that are not selected get label num_agg-1
                # condition connector (none / and / or)
                conn = [d['sql']['cond_conn_op']]
                csel = np.zeros(len(d['question']) + 2, dtype='int32') # 0 doubles as padding and as column 0; the padded part is masked during training
                cop = np.zeros(len(d['question']) + 2, dtype='int32') + num_op - 1 # positions not inside any condition value get label num_op-1
                for j in d['sql']['conds']:
                    if j[2] not in d['question']:
                        j[2] = most_similar_2(j[2], d['question'])
                    if j[2] not in d['question']:
                        continue
                    k = d['question'].index(j[2])
                    csel[k + 1: k + 1 + len(j[2])] = j[0]
                    cop[k + 1: k + 1 + len(j[2])] = j[1]
                if len(x1) > maxlen:
                    continue
                X1.append(x1) # BERT input: token ids
                X2.append(x2) # BERT input: segment ids
                XM.append(xm) # mask over the question part of the input
                H.append(h) # positions of the header [CLS] tokens
                HM.append(hm) # header mask
                SEL.append(sel) # aggregation label for every column (or "not selected")
                CONN.append(conn) # condition connector
                CSEL.append(csel) # condition column for every character
                COP.append(cop) # condition operator for every character (also marks the value spans)
                if len(X1) == self.batch_size:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    XM = seq_padding(XM, maxlen=X1.shape[1])
                    H = seq_padding(H)
                    HM = seq_padding(HM)
                    SEL = seq_padding(SEL)
                    CONN = seq_padding(CONN)
                    CSEL = seq_padding(CSEL, maxlen=X1.shape[1])
                    COP = seq_padding(COP, maxlen=X1.shape[1])
                    yield [X1, X2, XM, H, HM, SEL, CONN, CSEL, COP], None
                    X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
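# To see what the per-character condition labels look like, here is a standalone sketch of
# my own that repeats the csel/cop construction from the generator on the first training
# example (it assumes train_data, num_op and most_similar_2 from above are available).
d0 = train_data[0]
q0 = d0['question']
csel_demo = np.zeros(len(q0) + 2, dtype='int32')              # 0 = padding / column 0
cop_demo = np.zeros(len(q0) + 2, dtype='int32') + num_op - 1  # num_op-1 = outside any condition value
for col, op, val in d0['sql']['conds']:
    if val not in q0:
        val = most_similar_2(val, q0)   # same fallback as in the generator
    if val in q0:
        k = q0.index(val)
        csel_demo[k + 1: k + 1 + len(val)] = col   # which column the value belongs to
        cop_demo[k + 1: k + 1 + len(val)] = op     # which operator applies to it
for ch, c, o in zip(q0, csel_demo[1:], cop_demo[1:]):
    if o != num_op - 1:
        print(ch, c, o)   # only the characters inside '大黄蜂' and '密室逃生' carry condition labels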
def seq_gather(x):
    """seq是[None, seq_len, s_size]的格式,
    idxs是[None, n]的格式,在seq的第i个序列中选出第idxs[i]个向量,
    最终输出[None, n, s_size]的向量。
    """
    seq, idxs = x
    idxs = K.cast(idxs, 'int32')
    return K.tf.batch_gather(seq, idxs)
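# A toy example of my own showing, in plain numpy, what seq_gather / tf.batch_gather does:
# for each sequence in the batch it picks out the vectors at the positions given in idxs.
# This is how the encoder output at each header's [CLS] position is collected below.
seq_demo = np.arange(2 * 5 * 3).reshape(2, 5, 3)   # [batch=2, seq_len=5, s_size=3]
idxs_demo = np.array([[0, 2], [1, 4]])             # per-sample positions to pick
gathered = np.stack([seq_demo[b][idxs_demo[b]] for b in range(len(seq_demo))])
print(gathered.shape)   # (2, 2, 3): rows 0 and 2 of sample 0, rows 1 and 4 of sample 1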
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
for l in bert_model.layers:
    l.trainable = True
x1_in = Input(shape=(None,), dtype='int32')
x2_in = Input(shape=(None,))
xm_in = Input(shape=(None,))
h_in = Input(shape=(None,), dtype='int32')
hm_in = Input(shape=(None,))
sel_in = Input(shape=(None,), dtype='int32')
conn_in = Input(shape=(1,), dtype='int32')
csel_in = Input(shape=(None,), dtype='int32')
cop_in = Input(shape=(None,), dtype='int32')
x1, x2, xm, h, hm, sel, conn, csel, cop = (
    x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in
)
hm = Lambda(lambda x: K.expand_dims(x, 1))(hm) # header mask, shape=(None, 1, h_len)

x = bert_model([x1_in, x2_in])
x4conn = Lambda(lambda x: x[:, 0])(x)
pconn = Dense(num_cond_conn_op, activation='softmax')(x4conn)

x4h = Lambda(seq_gather)([x, h])
psel = Dense(num_agg, activation='softmax')(x4h)

pcop = Dense(num_op, activation='softmax')(x)

x = Lambda(lambda x: K.expand_dims(x, 2))(x)
x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
pcsel_1 = Dense(256)(x)
pcsel_2 = Dense(256)(x4h)
pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])
pcsel = Activation('tanh')(pcsel)
pcsel = Dense(1)(pcsel)
pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])
pcsel = Activation('softmax')(pcsel)
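# The pcsel branch above is essentially additive attention between every position of the
# input sequence and every header vector. A shape walk-through of my own (sizes made up):
#   x:   (batch, x_len, 768) -> expand_dims(axis=2) -> (batch, x_len, 1, 768) -> Dense(256)
#   x4h: (batch, h_len, 768) -> expand_dims(axis=1) -> (batch, 1, h_len, 768) -> Dense(256)
#   sum, tanh, Dense(1)      -> (batch, x_len, h_len, 1); drop the last axis,
#   mask padded headers with -1e10 and softmax over the header axis.
a_demo = np.zeros((8, 60, 1, 256))   # projected input-token vectors
b_demo = np.zeros((8, 1, 12, 256))   # projected header vectors
print((a_demo + b_demo).shape)       # (8, 60, 12, 256): every position paired with every header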
model = Model(
    [x1_in, x2_in, h_in, hm_in],
    [psel, pconn, pcop, pcsel]
)

train_model = Model(
    [x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in],
    [psel, pconn, pcop, pcsel]
)
xm = xm # question mask, shape=(None, x_len)
hm = hm[:, 0] # header mask, shape=(None, h_len)
cm = K.cast(K.not_equal(cop, num_op - 1), 'float32') # condition mask, shape=(None, x_len)

psel_loss = K.sparse_categorical_crossentropy(sel_in, psel)
psel_loss = K.sum(psel_loss * hm) / K.sum(hm)
pconn_loss = K.sparse_categorical_crossentropy(conn_in, pconn)
pconn_loss = K.mean(pconn_loss)
pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop)
pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm)
pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel)
pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm)
loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss

train_model.add_loss(loss)
train_model.compile(optimizer=Adam(learning_rate))
train_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, None, 768)    101677056   input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, None, 768)    0           model_2[1][0]                    
                                                                 input_4[0][0]                    
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, None, 1, 768) 0           model_2[1][0]                    
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1, None, 768) 0           lambda_3[0][0]                   
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, None, 1, 256) 196864      lambda_4[0][0]                   
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 1, None, 256) 196864      lambda_5[0][0]                   
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, None, None, 2 0           dense_4[0][0]                    
                                                                 dense_5[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 2 0           lambda_6[0][0]                   
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, None, None, 1 257         activation_1[0][0]               
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1, None)      0           input_5[0][0]                    
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 768)          0           model_2[1][0]                    
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, None, None)   0           dense_6[0][0]                    
                                                                 lambda_1[0][0]                   
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, None, 7)      5383        lambda_3[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3)            2307        lambda_2[0][0]                   
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, None, 5)      3845        model_2[1][0]                    
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None)   0           lambda_7[0][0]                   
==================================================================================================
Total params: 102,082,576
Trainable params: 102,082,576
Non-trainable params: 0
__________________________________________________________________________________________________
def nl2sql(question, table):
    """输入question和headers,转SQL
    """
    x1, x2 = tokenizer.encode(question)
    h = []
    for i in table['headers']:
        _x1, _x2 = tokenizer.encode(i)
        h.append(len(x1))
        x1.extend(_x1)
        x2.extend(_x2)
    hm = [1] * len(h)
    psel, pconn, pcop, pcsel = model.predict([
        np.array([x1]),
        np.array([x2]),
        np.array([h]),
        np.array([hm])
    ])
    R = {'agg': [], 'sel': []}
    for i, j in enumerate(psel[0].argmax(1)):
        if j != num_agg - 1: # class num_agg-1 means the column is not selected
            R['sel'].append(i)
            R['agg'].append(j)
    conds = []
    v_op = -1
    for i, j in enumerate(pcop[0, :len(question)+1].argmax(1)):
        # conditions are predicted by combining sequence labelling (pcop) with per-span column classification (pcsel)
        if j != num_op - 1:
            if v_op != j:
                if v_op != -1:
                    v_end = v_start + len(v_str)
                    csel = pcsel[0][v_start: v_end].mean(0).argmax()
                    conds.append((csel, v_op, v_str))
                v_start = i
                v_op = j
                v_str = question[i - 1]
            else:
                v_str += question[i - 1]
        elif v_op != -1:
            v_end = v_start + len(v_str)
            csel = pcsel[0][v_start: v_end].mean(0).argmax()
            conds.append((csel, v_op, v_str))
            v_op = -1
    R['conds'] = set()
    for i, j, k in conds:
        if re.findall('[^\d\.]', k):
            j = 2 # non-numeric values can only use the equality operator
        if j == 2:
            if k not in table['all_values']:
                # an equality value must appear somewhere in the table; otherwise fall back to the most similar one
                k = most_similar(k, list(table['all_values']))
            h = table['headers'][i]
            # then check that the value really belongs to the predicted column; if not, correct the column
            if k not in table['content'][h]:
                for r, v in table['content'].items():
                    if k in v:
                        i = table['header2id'][r]
                        break
        R['conds'].add((i, j, k))
    R['conds'] = list(R['conds'])
    if len(R['conds']) <= 1: # with at most one condition the connector is simply 0
        R['cond_conn_op'] = 0
    else:
        R['cond_conn_op'] = 1 + pconn[0, 1:].argmax() # with multiple conditions the connector cannot be 0
    return R
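# Usage sketch of my own: once the model is trained (or best_model.weights is loaded),
# nl2sql can be called directly on a question and its table; with untrained weights the
# prediction is of course meaningless.
d_val = valid_data[0]
table_val = valid_tables[d_val['table_id']]
pred_val = nl2sql(d_val['question'], table_val)
print(d_val['question'])
print('predicted:', pred_val)
print('gold:     ', d_val['sql'])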
def is_equal(R1, R2):
    """判断两个SQL字典是否全匹配
    """
    return (R1['cond_conn_op'] == R2['cond_conn_op']) &\
    (set(zip(R1['sel'], R1['agg'])) == set(zip(R2['sel'], R2['agg']))) &\
    (set([tuple(i) for i in R1['conds']]) == set([tuple(i) for i in R2['conds']]))
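# A tiny check of my own: is_equal compares the 'sel'/'agg' pairs and the conditions as sets,
# so the order of the conditions does not matter, only their exact content does.
gold_sql = {'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}
pred_sql = {'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [(0, 2, '密室逃生'), (0, 2, '大黄蜂')]}
print(is_equal(pred_sql, gold_sql))   # True: same SQL, different condition order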
def evaluate(data, tables):
    right = 0.
    pbar = tqdm()
    F = open('evaluate_pred.json', 'w')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        right += float(is_equal(R, d['sql']))
        pbar.update(1)
        pbar.set_description('< acc: %.5f >' % (right / (i + 1)))
        d['sql_pred'] = R
        s = json.dumps(d, ensure_ascii=False, indent=4)
        F.write(s + '\n')  # json.dumps already returns a str under Python 3; the original .encode('utf-8') is a Python 2 idiom
    F.close()
    pbar.close()
    return right / len(data)
def test(data, tables, outfile='result.json'):
    pbar = tqdm()
    F = open(outfile, 'w')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        pbar.update(1)
        s = json.dumps(R, ensure_ascii=False)
        F.write(s + '\n')  # write the str directly (Python 3)
    F.close()
    pbar.close()
class Evaluate(Callback):
    def __init__(self):
        self.accs = []
        self.best = 0.
        self.passed = 0
        self.stage = 0
    def on_batch_begin(self, batch, logs=None):
        """第一个epoch用来warmup,第二个epoch把学习率降到最低
        """
        if self.passed < self.params['steps']:
            lr = (self.passed + 1.) / self.params['steps'] * learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1
        elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
            lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
            lr += min_learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1
    def on_epoch_end(self, epoch, logs=None):
        acc = self.evaluate()
        self.accs.append(acc)
        if acc > self.best:
            self.best = acc
            train_model.save_weights('best_model.weights')
        print ('acc: %.5f, best acc: %.5f\n' % (acc, self.best))
    def evaluate(self):
        return evaluate(valid_data, valid_tables)
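# The two branches of on_batch_begin implement a linear warmup over the first epoch followed
# by a linear decay to min_learning_rate over the second. A small sketch of my own (assuming,
# hypothetically, 1000 steps per epoch) prints a few points of that schedule:
steps_demo = 1000
for passed in [0, 250, 500, 999, 1000, 1500, 1999]:
    if passed < steps_demo:                     # epoch 1: 0 -> learning_rate
        lr = (passed + 1.) / steps_demo * learning_rate
    else:                                       # epoch 2: learning_rate -> min_learning_rate
        lr = (2 - (passed + 1.) / steps_demo) * (learning_rate - min_learning_rate) + min_learning_rate
    print(passed, '%.2e' % lr)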
train_D = data_generator(train_data, train_tables)
evaluator = Evaluate()
if __name__ == '__main__':
    train_model.fit_generator(
        train_D.__iter__(),
        steps_per_epoch=len(train_D),
        epochs=15,
        callbacks=[evaluator]
    )
else:
    train_model.load_weights('best_model.weights')