Pytorch中Rnn的实现(0.2.0版)


参考博客:https://zhuanlan.zhihu.com/p/32103001
最近需要自己实现基于rnn的attention机制,所以参考了一下pytorch里面对rnn实现的方法。在最新版的pytorch源码中,rnn貌似直接调用的底层接口,没有找到实现的代码。在0.2.0版本中找到了实现的代码。在pytorch的实现中,RNN的实现并没有调用RNNCell类

使用到的文件列表:

  • torch/nn/modules/rnn.py:RNNBase
  • torch/nn/utils/rnn.py:PackedSequence
  • torch/nn/modules/module.py:Module
  • torch/nn/backends/thnn.py:backend.register_function(‘RNN’, RNN)
  • torch/nn/_functions/rnn.py:RNN,AutogradRNN,Recurrent,StackedRNN,***Cell

torch/nn/modules/rnn.py

RNN的入口是在RNNBase中。后面的LSTM,RNN,GRU实现的代码都特别短,仅仅是在构造函数中传入了一个mode的参数。
36到56行代码创建RNN所需要的参数。这里num_layers是rnn层数;num_directions是一个1或者2的值,如果是单向rnn则为1,双向rnn是2。

        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

                self._param_buf_size += sum(p.numel() for p in layer_params)

接下来在forward函数中,传入两个参数input和hx。hx是隐藏层的初始变量,也可以不传入,如果不传入的话则默认初始隐藏层为全0。通常传入的input是一个tensor,也可以是一个packedsequence。当传入tensor的时候就是定长的rnn;如果传入packedsequence,则是为了实现一个batch中样本长度不一样长的情况。packedSequence的实现只有一个pass,

PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes'])


class PackedSequence(PackedSequence_):
    """Holds the data and list of batch_sizes of a packed sequence.
    All RNN modules accept packed sequences as inputs.
    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.
    Attributes:
        data (Variable): Variable containing packed sequence
        batch_sizes (list[int]): list of integers holding information about
            the batch size at each sequence step
    """
    pass

如果要传入不定长的输入,需要先对样本进行打包。打包的函数和PackedSequence在同一个文件中。
比较重要的函数在149行,这里定义的func就是实际上我们用的时候调用的rnn。

        func = self._backend.RNN(
            ...
        )

这里出现了_backend,但是在这个文件中找不到_backend的定义。犹豫RNNBase继承了Module类,所以猜测_backend应该是在Module类里面的。
在Module类的定义中,真的可以找到_backend的定义。发现这个东西是thnn_backend。在thnn_backend中,找到了RNN的定义,

backend.register_function('RNN', RNN)

RNN实现

接下来就开始进入了实现RNN的核心部分。
在thnn_backend中的RNN指向了torch/nn/_functions/rnn.py的345行

def RNN(*args, **kwargs):
    def forward(input, *fargs, **fkwargs):
        if cudnn.is_acceptable(input.data):
            func = CudnnRNN(*args, **kwargs)# 使用GPU
        else:
            func = AutogradRNN(*args, **kwargs)# 不使用GPU
        return func(input, *fargs, **fkwargs)

    return forward

我们关注的是不使用GPU的版本,因为GPU的版本调用了cudnn的接口,看不到代码。
AutoGradRNN的定义在torch/nn/_functions/rnn.py中。
在函数最开始有判断使用的是RNN还是LSTM还是GRU,这里使用了LSTMCell。这里的LSTMCell并不是我们在使用pytorch时会用到的LSTMcell,而是一个函数。
在224行有一个判断batch_sizes的代码,batch_sizes如果存在的话就是不定长的输入,这里先看Recurrent,即定长的输入。

    if batch_sizes is None:
        rec_factory = Recurrent
    else:
        rec_factory = variable_recurrent_factory(batch_sizes)

Recurrent的定义:

def Recurrent(inner, reverse=False):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # hack to handle LSTM
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward

Recurrent实际上是返回一个函数,返回的函数的目的是将输入进行一次循环迭代,得到rnn的一层。
接下来使用定义好的rec_factory来构建rnn的。
在Recurrent完成之后,使用StackedRnn调用Recurrent得到的函数,StackedRNN的目的是完成多层的RNN。同样,StackenRNN也是一个返回函数的函数。
在这里插入图片描述
StackedRNN返回的函数,也就是AutoGradRNN的返回值。可以看到这个函数有三个参数,除了RNN的输入Input和hx,还有一个weight。这个权重就是在RNN的类中构造的那些参数。

发布了267 篇原创文章 · 获赞 12 · 访问量 15万+

猜你喜欢

转载自blog.csdn.net/u010734277/article/details/103847522
今日推荐