本文转载于使用Torch nngraph实现LSTM,原作者将代码和理论紧密结合,非常易于理解,故分享给大家。
LSTM介绍
定义:LSTM(Long-Short Term Memory,LSTM)
是一种时间递归神经网络,论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。——百度百科
下面是关于LSTM的公式:
Torch nngraph
nngraph可以方便设计一个神经网络模块。我们先使用nngraph创建一个简单的网络模块:
z=x1+x2⊙linear(x3)z=x1+x2⊙linear(x3)
可以看出这个模块的输入一共有三个,x1x1,x2x2和x3x3,输出是zz。下面是实现这个模块的torch代码:
require 'nngraph'
x1=nn.Identity()()
x2=nn.Identity()()
x3=nn.Identity()()
L=nn.CAddTable()({x1,nn.CMulTable()({x2,nn.Linear(20,10)(x3)})})
mlp=nn.gModule({x1,x2,x3},{L})
首先我们定义x1x1,x2x2和x3x3,使用nn.Identity()();然后对于linear(x3)linear(x3),我们使用x4=nn.Linear(20,10)(x3),定义了一个输入层有20个神经元,输出层有10个神经元的线性神经网络;对于x2⊙linear(x3)x2⊙linear(x3),使用x5=nn.CMulTable()(x2,x4);对于x1+x2⊙linear(x3)x1+x2⊙linear(x3),我们使用nn.CAddTable()(x1,x5) 实现;最后使用nn.gModule({input},{output})来定义神经网络模块。
我们使用forward方法测试我们的Module是否正确:
h1=torch.Tensor{1,2,3,4,5,6,7,8,9,10}
h2=torch.Tensor(10):fill(1)
h3=torch.Tensor(20):fill(2)
b=mlp:forward({h1,h2,h3})
parameters=mlp:parameters()[1]
bias=mlp:parameters()[2]
result=torch.cmul(h2,(parameters*h3+bias))+h1
首先我们定义三个输入h1h1,h2h2 和 h3h3,然后调用模块mpl的forward命令得到输出b,然后我们获取网络权重w和bias分别保存在parameters和bias变量中,计算z=h1+h2⊙linear(h3)z=h1+h2⊙linear(h3)的结果result=torch.cmul(h2,(parameters*h3+bias))+h1,最后比较result和b是否一致,我们发现计算的结果是一样的,说明我们的模块是正确的。
使用nngraph编写LSTM模块
现在我们使用nngraph写前文所描述的LSTM模块,代码如下:
require 'nngraph'
function lstm(xt,prev_c,prev_h)
function new_input_sum()
local i2h=nn.Linear(400,400)
local h2h=nn.Linear(400,400)
return nn.CAddTable()({i2h(xt),h2h(prev_h)})
end
local input_gate=nn.Sigmoid()(new_input_sum())
local forget_gate=nn.Sigmoid()(new_input_sum())
local output_gate=nn.Sigmoid()(new_input_sum())
local gt=nn.Tanh()(new_input_sum())
local ct=nn.CAddTable()({nn.CMulTable()({forget_gate,prev_c}),nn.CMulTable()({input_gate,gt})})
local ht=nn.CMulTable()({output_gate,nn.Tanh()(ct)})
return ct,ht
end
xt=nn.Identity()()
prev_c=nn.Identity()()
prev_h=nn.Identity()()
lstm=nn.gModule({xt,prev_c,prev_h},{lstm(xt,prev_c,prev_h)})
其中xt和prev_h是输入,prev_c是cell state,然后我们按照前文的公式一次计算,最后输出ct(new cell state),ht(输出)。代码的计算顺序与上文完全一致,所以这里就不再一一解释了。