pytorch实现RNN循环神经网络

RNN神经网络

 一、概述

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network),通过网络的内部结构捕捉序列之间的模式特征,一般也是以序列形式输出。

RNN(Recurrent Neural Network) 是带有循环的神经网络。

假设我们有一组向量(x^{(1)},x^{(2)},...,x^{(n)}),通过假设函数$H$解释

H^{(t)}=\sigma(W^{ht}\cdot X{(t)}+W^{hh}\cdot H^{(t-1)}+b_n)

t 代表时间量,\sigma是非线性函数,Y^{(t)}相对应的是每次循环的输出。我们往往取连续时间序列运算中的最后一次输出。

循环过程的展开,来表示RNN网络的层次深度。

RNN源于前馈神经网络,可以利用其内部状态(存储器)处理可变长的输入序列。这使得RNN更适用于未分段的、连续的手写识别或语音识别等任务。


RNN工作流程

向量 (x^{(1)},x^{(2)},...,x^{(n)})在每次循环时,导入一个元素,模型及时将内部状态H_{t-1}从单元格转移到下一个单元格。请注意,所有单元格使用相同的权重W

PyTorch中RNN

pytorch中的完整RNN通过torch.nn.RNN 类实现。

RNN — PyTorch 1.12 documentation

参数:

input_size – 输入样本x的特征数量

hidden_size – 隐藏状态的特征数

num_layers – 循环层数。例如,设置 num_layers=2 意味着将两个RNN叠在一起形成一个堆叠的RNN,第二个RNN接收第一个RNN的输出并计算最终结果。默认值为1

nonlinearity – 要使用的非线性函数。可以是"tanh""relu"。默认值: 'tanh'

bias – 如果为 False,则该层不使用偏置权重 $b_{ih}$ 和 $b_{hh}$。默认值:True

batch_first – 如果为True,则输入和输出张量将作为 (batch, seq, feature) 而不是 (seq, batch, feature) 。注意,这不适用于隐藏或单元格状态。默认值:False

dropout – 如果非零,则在除最后一层之外的每个RNN层的输出上引入一个Dropout 层,dropout 概率等于 dropout。默认值: 0

bidirectional – 如果为 True,则使用双向 RNN。默认值:False

Pytorch代码实现

class RNN_Classfication(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_class):
        super(RNN_Classfication, self).__init__()
 
        self.rnn = nn.RNN(        # RNN模型
            input_size = input_size,      # 图片每行的数据像素点(特征数)
            hidden_size = hidden_size,     # rnn 隐藏层单元数
            num_layers = 1,       # 有几层 RNN layers
            batch_first = True,   # 指定batch为第一维 e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(hidden_size, num_class)    # 输出层
 
    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (batch, hidden_size) rnn hidden
        r_out, h_n = self.rnn(x)   
        
        # 选取最后一个时间点的 r_out 输出
        # 这里 r_out[:, -1, :] 的值也是 h_n 的值
        out = self.out(r_out[:,-1,:])
        # out = self.out(h_n.squeeze(0))
        return out

猜你喜欢

转载自blog.csdn.net/m0_71145378/article/details/126934975