好文!使用Torch nngraph实现LSTM

本文转载于使用Torch nngraph实现LSTM,原作者将代码和理论紧密结合,非常易于理解,故分享给大家。

LSTM介绍

定义:LSTM(Long-Short Term Memory,LSTM) 
是一种时间递归神经网络,论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。——百度百科

下面是关于LSTM的公式: 

Torch nngraph 

nngraph可以方便设计一个神经网络模块。我们先使用nngraph创建一个简单的网络模块: 

z=x1+x2⊙linear(x3)z=x1+x2⊙linear(x3) 
可以看出这个模块的输入一共有三个,x1x1,x2x2和x3x3,输出是zz。下面是实现这个模块的torch代码:

require 'nngraph'
x1=nn.Identity()()
x2=nn.Identity()()
x3=nn.Identity()()
L=nn.CAddTable()({x1,nn.CMulTable()({x2,nn.Linear(20,10)(x3)})})
mlp=nn.gModule({x1,x2,x3},{L})

首先我们定义x1x1,x2x2和x3x3,使用nn.Identity()();然后对于linear(x3)linear(x3),我们使用x4=nn.Linear(20,10)(x3),定义了一个输入层有20个神经元,输出层有10个神经元的线性神经网络;对于x2⊙linear(x3)x2⊙linear(x3),使用x5=nn.CMulTable()(x2,x4);对于x1+x2⊙linear(x3)x1+x2⊙linear(x3),我们使用nn.CAddTable()(x1,x5) 实现;最后使用nn.gModule({input},{output})来定义神经网络模块。 
我们使用forward方法测试我们的Module是否正确: 

h1=torch.Tensor{1,2,3,4,5,6,7,8,9,10}
h2=torch.Tensor(10):fill(1)
h3=torch.Tensor(20):fill(2)
b=mlp:forward({h1,h2,h3})
parameters=mlp:parameters()[1]
bias=mlp:parameters()[2]
result=torch.cmul(h2,(parameters*h3+bias))+h1

首先我们定义三个输入h1h1,h2h2 和 h3h3,然后调用模块mpl的forward命令得到输出b,然后我们获取网络权重w和bias分别保存在parameters和bias变量中,计算z=h1+h2⊙linear(h3)z=h1+h2⊙linear(h3)的结果result=torch.cmul(h2,(parameters*h3+bias))+h1,最后比较result和b是否一致,我们发现计算的结果是一样的,说明我们的模块是正确的。 

使用nngraph编写LSTM模块 

现在我们使用nngraph写前文所描述的LSTM模块,代码如下:

require 'nngraph'
function lstm(xt,prev_c,prev_h)
    function new_input_sum()
        local i2h=nn.Linear(400,400)
        local h2h=nn.Linear(400,400)
        return nn.CAddTable()({i2h(xt),h2h(prev_h)})
    end
    local input_gate=nn.Sigmoid()(new_input_sum())
    local forget_gate=nn.Sigmoid()(new_input_sum())
    local output_gate=nn.Sigmoid()(new_input_sum())
    local gt=nn.Tanh()(new_input_sum())
    local ct=nn.CAddTable()({nn.CMulTable()({forget_gate,prev_c}),nn.CMulTable()({input_gate,gt})})
    local ht=nn.CMulTable()({output_gate,nn.Tanh()(ct)})
    return ct,ht
end
xt=nn.Identity()()
prev_c=nn.Identity()()
prev_h=nn.Identity()()
lstm=nn.gModule({xt,prev_c,prev_h},{lstm(xt,prev_c,prev_h)})

 其中xtprev_h是输入,prev_c是cell state,然后我们按照前文的公式一次计算,最后输出ct(new cell state),ht(输出)。代码的计算顺序与上文完全一致,所以这里就不再一一解释了。

猜你喜欢

转载自blog.csdn.net/tiaojingtao1293/article/details/81207532
今日推荐