Introducción al aprendizaje profundo (58) Red neuronal recurrente: implementación simple de la red neuronal recurrente

Introducción al aprendizaje profundo (58) Red neuronal recurrente: implementación simple de la red neuronal recurrente

prefacio

El contenido principal proviene del enlace 1 del blog. Enlace 2 del blog. Espero que puedas apoyar mucho al autor.
Este artículo se usa para registros para evitar el olvido.

Redes neuronales recurrentes: implementación concisa de redes neuronales recurrentes

Libro de texto

Si bien la sección anterior fue instructiva para comprender cómo se implementan las redes neuronales recurrentes, no es conveniente. Esta sección mostrará cómo implementar el mismo modelo de lenguaje de manera más eficiente utilizando las funciones proporcionadas por la API de alto nivel del marco de aprendizaje profundo. Todavía comenzamos leyendo el conjunto de datos de Time Machine.

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1 Definir el modelo

La API de alto nivel proporciona implementaciones de redes neuronales recurrentes. Construimos una capa de red neuronal recurrente con una sola capa oculta de 256 unidades ocultas rnn_layer. De hecho, no hemos discutido la importancia de las redes neuronales recurrentes multicapa. Ahora es suficiente entender múltiples capas como la salida de una capa de red neuronal recurrente que se utiliza como entrada de la siguiente capa de red neuronal recurrente.

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

Usamos un tensor para inicializar el estado oculto, su forma es (número de capas ocultas, tamaño del lote, número de unidades ocultas).

nn.RNNEl estado oculto devuelto por el ejemplo en el cálculo directo se refiere al estado oculto de la capa oculta en el último paso de tiempo : cuando hay varias capas de la capa oculta, el estado oculto de cada capa se registrará en esta variable; para memoria a corto plazo (LSTM), el estado oculto es una tupla (h, c), es decir, estado oculto y estado de celda. Presentaremos la memoria a corto plazo y las redes neuronales recurrentes profundas más adelante en este capítulo.

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

producción:

torch.Size([1, 32, 256])

Con un estado oculto y una entrada, podemos calcular la salida con el estado oculto actualizado. Es importante enfatizar que rnn_layerla "salida" de (Y) no implica el cálculo de la capa de salida: se refiere a los estados ocultos en cada paso de tiempo que pueden usarse como entrada para las capas de salida posteriores.

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

producción:

(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

Similar a la sección anterior, definimos una RNNModelclase para un modelo de red neuronal recurrente completo. Tenga en cuenta que rnn_layersolo se incluyen capas recurrentes ocultas, también necesitamos crear una capa de salida separada.

class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

2 Entrenamiento y Predicción

Antes de entrenar el modelo, hagamos predicciones basadas en un modelo con pesos aleatorios.

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

Obviamente, este tipo de modelo no puede generar buenos resultados en absoluto. A continuación, usamos las llamadas de hiperparámetro definidas en la sección anterior train_ch8y usamos la API de alto nivel para entrenar el modelo.

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

producción:

perplexity 1.3, 286908.2 tokens/sec on cuda:0
time traveller came the time traveller but now you begin to spen
traveller pork acong wa canome precable thig thit lepanchat

En comparación con la sección anterior, el modelo logra una menor perplejidad en un tiempo más corto debido a más optimizaciones de código por parte de la API de alto nivel del marco de aprendizaje profundo.

3 Resumen

  • La API de alto nivel del marco de aprendizaje profundo proporciona la implementación de capas de redes neuronales recurrentes.

  • La capa de red neuronal recurrente de la API de alto nivel devuelve una salida y un estado oculto actualizado, también necesitamos calcular la capa de salida de todo el modelo.

  • El uso de una implementación de API de alto nivel acelera el entrenamiento en comparación con la implementación de redes neuronales recurrentes desde cero.

Supongo que te gusta

Origin blog.csdn.net/qq_52358603/article/details/128275418
Recomendado
Clasificación