03- Реализация LSTM с нуля

Введение в алгоритмы долговременной кратковременной памяти

Алгоритм LSTM (Long Short-Term Memory) — это обычный алгоритм рекуррентной нейронной сети (RNN), используемый для обработки данных последовательности, и он очень хорошо работает при обработке данных временных рядов.

Основная идея алгоритма LSTM — добавление блока памяти на основе RNN, который может помочь сети запоминать прошлые состояния и обновлять их при необходимости. В то же время LSTM также контролирует поток информации через три шлюза: ворота забывания, входные ворота и выходные ворота . Эти ворота позволяют сети выбирать, когда забывать старые состояния, принимать новые входные данные и выводить текущее состояние.

Вторая реализация алгоритма

2.1 Пакет руководств

# 从零实现
import torch
from torch import nn
import dltools

2.2 Импорт данных обучения

batch_size, num_steps = 32, 35
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps=num_steps)

2.3 Инициализация параметров модели

# 初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    def three():
        return (normal((num_inputs, num_hiddens)),
               normal((num_hiddens, num_hiddens)),
               torch.zeros(num_hiddens, device=device))
    W_xi, W_hi, b_i = three()  # 输入参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    
    # 输出层
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 附加梯度
    params = [W_xi, W_hi, b_i,
              W_xf, W_hf, b_f,
              W_xo, W_ho, b_o,
              W_xc, W_hc, b_c, W_hq, b_q]
    
    for param in params:
        param.requires_grad_(True)
    
    return params

Инициализируйте параметры модели LSTM . Сначала рассчитываются входные и выходные измерения на основе размера словаря и количества скрытых единиц. Далее определяется нормальная функция для генерации случайных чисел, подчиняющихся стандартному нормальному распределению, и умножения случайных чисел на 0,01 для управления диапазоном начальных значений параметров.

Далее определяется внутренняя функция три для генерации трех параметров. Каждый параметр представляет собой кортеж, содержащий матрицу весов и вектор смещения между входными и скрытыми единицами. Здесь используется метод инициализации Normal для генерации весовой матрицы, а вектор смещения инициализируется всеми нулями .

Затем вызовите внутреннюю функцию три, чтобы инициализировать параметры входа, забыть вентиль, выходной вентиль и ячейки памяти-кандидаты соответственно .

Далее нормальная функция используется для инициализации параметров выходного слоя, включая матрицу весов и вектор смещения скрытой единицы и вывода.

Наконец, поместите все параметры в список и установите для require_grad значение True, указывая, что необходимо вычислить градиент параметров.

Возвращает список параметров.

2.4 Инициализация скрытого состояния

# 初始化隐藏状态和记忆元
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
           torch.zeros((batch_size, num_hiddens), device=device))

Инициализируйте скрытые состояния и ячейки памяти сети долгосрочной краткосрочной памяти (LSTM). Входные параметры включают размер пакета (batch_size), количество скрытых модулей (num_hiddens) и вычислительное устройство (device).

Функция возвращает кортеж, содержащий два тензора, представляющие скрытое состояние и элемент памяти . Эти тензоры имеют размер (batch_size, num_hiddens) и инициализируются всеми нулями.

2.5 Определить структуру LSTM

# 定义 LSTM 主体结构
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i,
     W_xf, W_hf, b_f,
     W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q] = params
    
    (H, C) = state
    outputs = []
    
    # 准备开始进行前向传播
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

Обучение трех моделей

3.1 Тестовый код

# 测试代码
X = torch.arange(10).reshape((2, 5))
num_hiddens = 512
net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_lstm_params, init_lstm_state, lstm)
state = net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = net(X.to(dltools.try_gpu()), state)

3.2 Онлайн-обучение

# 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 1
model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Четыре реализации Pytorch

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = dltools.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

рекомендация

отblog.csdn.net/March_A/article/details/132841428