TensorFlow常见模型实现之一(LSTM/BiLSTM)

1. LSTM

import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops

class lstm(object):
    """Unidirectional LSTM encoder built with ``tf.nn.dynamic_rnn`` (TF 1.x).

    The graph is built in the constructor and the selected output tensor is
    stored in ``self.out``.

    Args:
        in_data: tensor passed to ``tf.nn.dynamic_rnn`` as ``inputs``. With the
            default ``time_major=False`` it is presumably shaped
            ``(batch, max_time, input_dim)`` -- confirm against callers.
        hidden_dim: number of units in the LSTM cell.
        batch_seqlen: optional per-example sequence lengths, forwarded to
            ``sequence_length``.
        flag: which output to expose:
            ``'all_ht'``   - the full ``outputs`` tensor (every time step);
            ``'first_ht'`` - the output at time step 0;
            ``'last_ht'``  - the hidden state at each example's last valid step;
            ``'concat'``   - ``first_ht`` and ``last_ht`` concatenated on axis 1.

    Raises:
        ValueError: if ``flag`` is not one of the values above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate before building any graph ops. Previously an unknown flag
        # left ``self.out`` unset and only failed later with AttributeError.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of %s, got %r" % (self._FLAGS, flag))

        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
        out, state = tf.nn.dynamic_rnn(
            cell=lstm_cell,
            inputs=self.in_data,
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        # When ``sequence_length`` is given, dynamic_rnn zeroes ``out`` past
        # each example's length, so ``out[:, -1, :]`` would be all zeros for
        # shorter sequences. The final state's ``h`` holds the hidden state
        # at the last valid step. It equals ``out[:, -1, :]`` when no lengths
        # are passed.
        last_ht = state.h

        if flag == 'all_ht':
            self.out = out
        elif flag == 'first_ht':
            self.out = out[:, 0, :]
        elif flag == 'last_ht':
            self.out = last_ht
        else:  # 'concat' (guaranteed by the validation above)
            self.out = tf.concat([out[:, 0, :], last_ht], 1)

2. Bi-LSTM

import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes

class bilstm(object):
    """Bidirectional LSTM encoder built with ``tf.nn.bidirectional_dynamic_rnn``.

    The graph is built in the constructor and the selected output tensor is
    stored in ``self.out``.

    Args:
        in_data: tensor passed to ``bidirectional_dynamic_rnn`` as ``inputs``.
            With the default ``time_major=False`` it is presumably shaped
            ``(batch, max_time, input_dim)`` -- confirm against callers.
        hidden_dim: number of units in each of the forward/backward cells.
        batch_seqlen: optional per-example sequence lengths, forwarded to
            ``sequence_length``.
        flag: which output to expose:
            ``'all_ht'``   - forward and backward outputs concatenated on axis 2;
            ``'first_ht'`` - that concatenated output at time step 0;
            ``'last_ht'``  - final ``h`` of the forward and backward cells,
                             concatenated on axis 1;
            ``'concat'``   - ``first_ht`` and ``last_ht`` concatenated on axis 1.

    Raises:
        ValueError: if ``flag`` is not one of the values above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate before building any graph ops. Previously an unknown flag
        # left ``self.out`` unset and only failed later with AttributeError.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of %s, got %r" % (self._FLAGS, flag))

        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)
        lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)
        # The keyword is ``sequence_length``. The former ``sequence_lenth``
        # typo is not a parameter of this function and raised TypeError.
        out, state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell_fw,
            cell_bw=lstm_cell_bw,
            inputs=self.in_data,
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        # ``out`` is (output_fw, output_bw); join them on the feature axis.
        bi_out = tf.concat(out, 2)
        # ``state`` is (state_fw, state_bw); join the final hidden states.
        final_h = tf.concat([state[0].h, state[1].h], 1)

        if flag == 'all_ht':
            self.out = bi_out
        elif flag == 'first_ht':
            self.out = bi_out[:, 0, :]
        elif flag == 'last_ht':
            self.out = final_h
        else:  # 'concat' (guaranteed by the validation above)
            self.out = tf.concat([bi_out[:, 0, :], final_h], 1)

 

猜你喜欢

转载自blog.csdn.net/u011195431/article/details/83041882