LSTM Internal Structure and Implementation

Theory reference: https://blog.csdn.net/banxin1995/article/details/85332465
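
For orientation, these are the standard LSTM equations that the manual implementation below computes step by step ($x_t$ is the embedded input at step $t$, $h_{t-1}$ the previous hidden state, $c_{t-1}$ the previous cell state, $\odot$ elementwise multiplication):

    \begin{aligned}
    f_t &= \sigma(x_t W_{fx} + h_{t-1} W_{fh} + b_f) \\
    i_t &= \sigma(x_t W_{ix} + h_{t-1} W_{ih} + b_i) \\
    o_t &= \sigma(x_t W_{ox} + h_{t-1} W_{oh} + b_o) \\
    \tilde{c}_t &= \tanh(x_t W_{cx} + h_{t-1} W_{ch} + b_c) \\
    c_t &= i_t \odot \tilde{c}_t + f_t \odot c_{t-1} \\
    h_t &= o_t \odot \tanh(c_t)
    \end{aligned}

In the code, `forget_gate`, `input_gate`, `output_gate`, and `mid_state` correspond to $f_t$, $i_t$, $o_t$, and $\tilde{c}_t$; `state` is the cell state $c_t$ and `h` is the hidden state $h_t$.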

Code:

    # hps, batch_size, num_timesteps, keep_prob, embed_inputs, and lstm_init
    # are assumed to be defined earlier (not shown in this excerpt).
    with tf.variable_scope('lstm_nn', initializer = lstm_init):
        """
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = tf.contrib.rnn.BasicLSTMCell(
                hps.num_lstm_nodes[i],
                state_is_tuple = True)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell,
                output_keep_prob = keep_prob)
            cells.append(cell)
        cell = tf.contrib.rnn.MultiRNNCell(cells)
        
        initial_state = cell.zero_state(batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(
            cell, embed_inputs, initial_state = initial_state)
        last = rnn_outputs[:, -1, :]
        """
        with tf.variable_scope('inputs'):
            ix, ih, ib = _generate_params_for_lstm_cell(
                x_size = [hps.num_embedding_size, hps.num_lstm_nodes[0]],
                h_size = [hps.num_lstm_nodes[0], hps.num_lstm_nodes[0]],
                bias_size = [1, hps.num_lstm_nodes[0]]
            )
        with tf.variable_scope('outputs'):
            ox, oh, ob = _generate_params_for_lstm_cell(
                x_size = [hps.num_embedding_size, hps.num_lstm_nodes[0]],
                h_size = [hps.num_lstm_nodes[0], hps.num_lstm_nodes[0]],
                bias_size = [1, hps.num_lstm_nodes[0]]
            )
        with tf.variable_scope('forget'):
            fx, fh, fb = _generate_params_for_lstm_cell(
                x_size = [hps.num_embedding_size, hps.num_lstm_nodes[0]],
                h_size = [hps.num_lstm_nodes[0], hps.num_lstm_nodes[0]],
                bias_size = [1, hps.num_lstm_nodes[0]]
            )
        with tf.variable_scope('memory'):
            cx, ch, cb = _generate_params_for_lstm_cell(
                x_size = [hps.num_embedding_size, hps.num_lstm_nodes[0]],
                h_size = [hps.num_lstm_nodes[0], hps.num_lstm_nodes[0]],
                bias_size = [1, hps.num_lstm_nodes[0]]
            )
        # Cell state and hidden state, initialized to zeros and excluded from
        # training; inside the loop below these names are rebound to ordinary
        # tensors, so the Variables only provide the t = 0 values.
        state = tf.Variable(
            tf.zeros([batch_size, hps.num_lstm_nodes[0]]),
            trainable = False
        )
        h = tf.Variable(
            tf.zeros([batch_size, hps.num_lstm_nodes[0]]),
            trainable = False
        )
        
        for i in range(num_timesteps):
            # embed_inputs: [batch_size, num_timesteps, embed_size];
            # slicing out timestep i gives [batch_size, embed_size].
            embed_input = embed_inputs[:, i, :]
            # The reshape is a no-op that just pins down the static shape.
            embed_input = tf.reshape(embed_input,
                                     [batch_size, hps.num_embedding_size])
            # Gates (sigmoid) and candidate state (tanh), following the
            # LSTM equations above.
            forget_gate = tf.sigmoid(
                tf.matmul(embed_input, fx) + tf.matmul(h, fh) + fb)
            input_gate = tf.sigmoid(
                tf.matmul(embed_input, ix) + tf.matmul(h, ih) + ib)
            output_gate = tf.sigmoid(
                tf.matmul(embed_input, ox) + tf.matmul(h, oh) + ob)
            mid_state = tf.tanh(
                tf.matmul(embed_input, cx) + tf.matmul(h, ch) + cb)
            # New cell state and hidden state.
            state = mid_state * input_gate + state * forget_gate
            h = output_gate * tf.tanh(state)
        # Hidden state after the last timestep:
        # [batch_size, hps.num_lstm_nodes[0]].
        last = h
    
    # Initializer for the fully connected layer that follows this excerpt.
    fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)
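
The code above calls `_generate_params_for_lstm_cell`, which is not shown in this excerpt. A minimal sketch consistent with its call sites is below: it creates the input-to-hidden weights, hidden-to-hidden weights, and bias for one gate. The variable names are assumptions; the weights pick up `lstm_init` from the enclosing variable scope.

    def _generate_params_for_lstm_cell(x_size, h_size, bias_size):
        """Creates the three parameter tensors for one LSTM gate."""
        # Input-to-hidden weights: [embedding_size, num_lstm_nodes[0]].
        x_w = tf.get_variable('x_weights', x_size)
        # Hidden-to-hidden weights: [num_lstm_nodes[0], num_lstm_nodes[0]].
        h_w = tf.get_variable('h_weights', h_size)
        # Bias, broadcast across the batch: [1, num_lstm_nodes[0]].
        b = tf.get_variable('biases', bias_size,
                            initializer=tf.constant_initializer(0.0))
        return x_w, h_w, b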
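
The excerpt ends right after `fc_init` is created. A plausible continuation that consumes `last` is a small fully connected classifier; `hps.num_fc_nodes` and `num_classes` are assumed names here:

    with tf.variable_scope('fc', initializer = fc_init):
        # Hidden FC layer on top of the final LSTM hidden state.
        fc1 = tf.layers.dense(last, hps.num_fc_nodes,
                              activation = tf.nn.relu, name = 'fc1')
        fc1_dropout = tf.nn.dropout(fc1, keep_prob)
        # Unnormalized class scores.
        logits = tf.layers.dense(fc1_dropout, num_classes, name = 'fc2')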


Reposted from blog.csdn.net/qq_36309884/article/details/88908030