炼丹技巧 | TensorFlow模型线上私密部署

前情回顾

我们之前介绍了

BERT的原理与应用

BERT与其他预训练模型

BERT四大下游任务

炼丹技巧 | BERT的下接结构调参

现在我们基于(2019BDCI互联网金融新实体发现 | 思路与代码框架分享(单模第一,综合第二))代码实践来介绍一下模型在线上部署的时候不暴露我们的model结构代码,也就是将自己的模型匿名,但却可以在线上持续提供服务。

具体项目代码链接:

https://github.com/ChileWang0228/Deep-Learning-With-Python/tree/master/chapter8

首先看一下我们的模型的预测代码:predict.py,我们模型要上线服务,predict.py是必须的,所以要想加密我们的模型结构代码,也只能从predict.py下手。我们可以看到line 4~line 14是读取我们模型预测所需要的变量,之后就可以直接预测了,并不需要重新使用model.py重新构图去读。

那么这是怎么完成的呢?

原因就在我们在model.py设置我们预测所需要的变量名,到预测的时候(predict)直接将模型读进来,然后抽取所需的变量来预测即可,避免暴露模型结构的代码。

# predict.py




1.  def get_session(checkpoint_path):  
2.    # 隐藏不重要其他代码 
3.    # 读取模型变量 
4.    _input_x = graph.get_operation_by_name("input_x_word").outputs[0]  
5.    _input_x_len = graph.get_operation_by_name("input_x_len").outputs[0]  
6.    _input_mask = graph.get_operation_by_name("input_mask").outputs[0]  
7.    _input_relation = graph.get_operation_by_name("input_relation").outputs[0]
8.    _keep_ratio = graph.get_operation_by_name('dropout_keep_prob').outputs[0]
9.    _is_training = graph.get_operation_by_name('is_training').outputs[0]  
10.    used = tf.sign(tf.abs(_input_x))  
11.    length = tf.reduce_sum(used, reduction_indices=1)  
12.    lengths = tf.cast(length, tf.int32)  
13.    logits = graph.get_operation_by_name('project/pred_logits').outputs[0]  
14.    trans = graph.get_operation_by_name('transitions').outputs[0]  
15.    
16.    def run_predict(feed_dict):  
17.        return session.run([logits, lengths, trans], feed_dict)  
18.    print('recover from: {}'.format(checkpoint_path))  
19.    return run_predict, (_input_x, _input_x_len, _input_mask, _input_relation, _keep_ratio, _is_training)

我们可以看到predict.py中的line 4 ~ line 14是我们所需的变量。

对应到model.py,我们分别给他们附上变量名(通过name_scope、name赋值)即可,代码如下。

# model.py




class Model:
    def __init__(self, config):
        self.config = config
        # 喂入模型的数据占位符
        self.input_x_word = tf.placeholder(tf.int32, [None, None], name="input_x_word")
        self.input_x_len = tf.placeholder(tf.int32, name='input_x_len')
        self.input_mask = tf.placeholder(tf.int32, [None, None], name='input_mask')
        self.input_relation = tf.placeholder(tf.int32, [None, None], name='input_relation')  # 实体NER的真实标签
        self.keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
        self.is_training = tf.placeholder(tf.bool, None, name='is_training')




       # 隐藏不重要代码




       # CRF超参数
        used = tf.sign(tf.abs(self.input_x_word))
        length = tf.reduce_sum(used, reduction_indices=1)
        self.lengths = tf.cast(length, tf.int32)
      
    # 隐藏不重要代码




    def project_layer(self, lstm_outputs, name=None):
        """
        hidden layer between lstm layer and logits
        :param lstm_outputs: [batch_size, num_steps, emb_size]
        :return: [batch_size, num_steps, num_tags]
        """




        with tf.name_scope("project" if not name else name):
        # 隐藏不重要代码
            # project to score of tags
            with tf.name_scope("logits"):
                W = tf.get_variable("LW", shape=[self.lstm_dim, self.relation_num],
                                    dtype=tf.float32, initializer=self.initializer)




                b = tf.get_variable("Lb", shape=[self.relation_num], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())




                pred = tf.nn.xw_plus_b(hidden, W, b)




 def loss_layer(self, project_logits, lengths, name=None):
        """
        计算CRF的loss
        :param project_logits: [1, num_steps, num_tags]
        :return: scalar loss
        """
        with tf.name_scope("crf_loss" if not name else name):
           # 隐藏不重要代码
            self.trans = tf.get_variable(
                name="transitions",
                shape=[self.relation_num + 1, self.relation_num + 1],  # 1
                # shape=[self.relation_num, self.relation_num],  # 1
                initializer=self.initializer)

   

笔者将在下期介绍

代码实践|LSTM实例之作诗机器人

敬请期待~

 

关注我的微信公众号~不定期更新相关专业知识~

内容 |阿力阿哩哩 

编辑 | 阿璃 

点个“在看”,作者高产似那啥~

发布了76 篇原创文章 · 获赞 5 · 访问量 6218

猜你喜欢

转载自blog.csdn.net/Chile_Wang/article/details/104386358
今日推荐