(十二)长短期记忆(LSTM)

1、介绍

长短期记忆修改了循环神经网络隐藏状态的计算方式,并引入了与隐藏状态形状相同的记忆细胞(某些文献把记忆细胞当成一种特殊的隐藏状态)。

2、具体的设计

(1)输入门、遗忘门和输出门

I t = σ ( X t W x i + H t 1 W h i + b i ) , F t = σ ( X t W x f + H t 1 W h f + b f ) , O t = σ ( X t W x o + H t 1 W h o + b o ) .

和门控循环单元中的重置门和更新门一样,这里的输入门、遗忘门和输出门中每个元素的值域都是[0,1]( 这里的 i , f , o 大小都是 h ,即超参数只有num_hiddens)

(2)候选记忆细胞

C ~ t = tanh ( X t W x c + H t 1 W h c + b c ) .

(3)记忆细胞

C t = F t C t 1 + I t C ~ t .

如果遗忘门一直近似1且输入门一直近似0,过去的记忆细胞将一直通过时间保存并传递至当前时间步。 这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时序数据中间隔较大的依赖关系。

(4)隐藏状态

通过输出门来控制从记忆细胞到隐藏状态

H t = O t tanh ( C t ) .

(5)输出层

在时间步t,长短期记忆的输出层计算和之前描述的循环神经网络输出层计算一样

3、代码

超参数num_hiddens定义了隐藏单元的个数

扫描二维码关注公众号,回复: 1920393 查看本文章
ctx = gb.try_gpu()
input_dim = vocab_size
num_hiddens = 256
output_dim = vocab_size

def get_params():
    # 输入门参数.
    W_xi = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
                            ctx=ctx)
    W_hi = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
                            ctx=ctx)
    b_i = nd.zeros(num_hiddens, ctx=ctx)
    # 遗忘门参数。
    W_xf = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
                            ctx=ctx)
    W_hf = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
                            ctx=ctx)
    b_f = nd.zeros(num_hiddens, ctx=ctx)
    # 输出门参数。
    W_xo = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
                            ctx=ctx)
    W_ho = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
                            ctx=ctx)
    b_o = nd.zeros(num_hiddens, ctx=ctx)
    # 候选细胞参数。
    W_xc = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
                            ctx=ctx)
    W_hc = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
                            ctx=ctx)
    b_c = nd.zeros(num_hiddens, ctx=ctx)
    # 输出层参数。
    W_hy = nd.random_normal(scale=0.01, shape=(num_hiddens, output_dim),
                            ctx=ctx)
    b_y = nd.zeros(output_dim, ctx=ctx)

    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hy, b_y]
    for param in params:
        param.attach_grad()
    return params



def lstm_rnn(inputs, state_h, state_c, *params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hy, b_y] = params
    H = state_h
    C = state_c
    outputs = []
    for X in inputs:
        I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)
        F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)
        O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)
        C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * C.tanh()
        Y = nd.dot(H, W_hy) + b_y
        outputs.append(Y)
    return (outputs, H, C)

猜你喜欢

转载自blog.csdn.net/hao5335156/article/details/80635282