In-depth understanding of LSTM: cases and code explanations

Introduction:
Long Short-Term Memory (LSTM) is a special type of Recurrent Neural Network (RNN) designed to capture long-term dependencies in sequential data. This article walks readers through how LSTM works and how it is applied, using a concrete case study and a detailed code explanation.

Case background:
Suppose we want to perform a sentiment classification task: judging whether the sentiment of a given text is positive or negative. We will use the IMDB movie review dataset, which contains 25,000 movie reviews, half positive and half negative.

Introduction to the LSTM model:
LSTM is a special type of RNN that alleviates the vanishing and exploding gradient problems of plain RNNs through a gating mechanism. An LSTM has three key gating units: the input gate, the forget gate, and the output gate, which control how information flows in, is discarded, and flows out, respectively. LSTM also maintains a cell state that stores and carries information across long time spans, as the sketch below illustrates.
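
To make the gating mechanism concrete, here is a minimal sketch of a single LSTM time step written out by hand. The stacked weight layout (W, U, b holding all four gate transforms together) is an illustrative assumption, not the exact parameter layout used by nn.LSTM:

import torch

def lstm_cell_step(x_t, h_prev, c_prev, W, U, b):
    # One LSTM time step, written out to show the three gates.
    # Assumed shapes: W [4*hidden, input], U [4*hidden, hidden], b [4*hidden]
    gates = x_t @ W.T + h_prev @ U.T + b
    i, f, g, o = gates.chunk(4, dim=-1)

    i = torch.sigmoid(i)  # input gate: how much new information to write
    f = torch.sigmoid(f)  # forget gate: how much old cell state to keep
    g = torch.tanh(g)     # candidate values for the cell state
    o = torch.sigmoid(o)  # output gate: how much cell state to expose

    c_t = f * c_prev + i * g    # cell state carries long-term information
    h_t = o * torch.tanh(c_t)   # hidden state is the gated cell output
    return h_t, c_t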

Code implementation:
Next, we will implement the LSTM model using the PyTorch library, then train and test it.

First, we import the required libraries and modules:

import torch
import torch.nn as nn
import torchtext
# This example uses torchtext's legacy API (torchtext < 0.9; in versions
# 0.9-0.11 the same classes live under torchtext.legacy)
from torchtext.datasets import IMDB
from torchtext.data import Field, LabelField, BucketIterator

Then, we define the LSTM model class:

class LSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size):
        super(LSTM, self).__init__()

        self.hidden_size = hidden_size

        # Embedding layer maps token indices to dense vectors;
        # raw indices cannot be fed straight into nn.LSTM
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # batch_first=True matches Field(batch_first=True) used below
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, text):
        # text: [batch_size, seq_len] of token indices
        embedded = self.embedding(text)
        output, (hidden, cell) = self.lstm(embedded)
        # hidden[-1] is the final hidden state of the last layer: [batch, hidden]
        # Return raw logits; BCEWithLogitsLoss applies the sigmoid
        return self.fc(hidden[-1])
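
As a quick sanity check, we can push a batch of random token indices through the model and confirm the output shape. This is an illustrative snippet; the sizes match the hyperparameters used in the training script below (a vocabulary of 10,000 words plus two special tokens):

model = LSTM(vocab_size=10002, embed_size=100, hidden_size=128, output_size=1)
dummy = torch.randint(0, 10002, (32, 500))  # [batch_size, seq_len] of token ids
print(model(dummy).shape)                   # torch.Size([32, 1])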

Next, we perform data preparation and preprocessing:

# Define the text field and label field
TEXT = Field(lower=True, batch_first=True, fix_length=500)
LABEL = LabelField(dtype=torch.float)

# Load the dataset
train_data, test_data = IMDB.splits(TEXT, LABEL)

# Build the vocabulary (keep the 10,000 most frequent words)
TEXT.build_vocab(train_data, max_size=10000)
LABEL.build_vocab(train_data)

# Create iterators that batch together examples of similar length
train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data),
    batch_size=32,
    sort_key=lambda x: len(x.text),
    repeat=False
)
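
Before training, it can help to check what the preprocessing produced. These are optional sanity checks (the exact tokens and label order depend on the dataset build):

print(len(TEXT.vocab))      # 10002: the 10,000 most frequent words plus <unk> and <pad>
print(TEXT.vocab.itos[:5])  # most frequent tokens in the vocabulary
print(LABEL.vocab.stoi)     # label-to-index mapping, e.g. {'neg': 0, 'pos': 1}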

Next, we define the training and evaluation functions, then train and test the model:

# Initialize the model and optimizer
vocab_size = len(TEXT.vocab)
embed_size = 100
hidden_size = 128
output_size = 1

lstm = LSTM(vocab_size, embed_size, hidden_size, output_size)
# BCEWithLogitsLoss combines a sigmoid with binary cross-entropy,
# so the model can return raw logits
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(lstm.parameters())

# Training function
def train(model, iterator, optimizer, criterion):
    model.train()

    for batch in iterator:
        optimizer.zero_grad()

        # batch.text is a [batch_size, seq_len] tensor of token indices
        # (no lengths tuple, since include_lengths was not set on the Field)
        text = batch.text
        predictions = model(text)

        loss = criterion(predictions.squeeze(1), batch.label)
        loss.backward()
        optimizer.step()
        
# Evaluation function
def evaluate(model, iterator, criterion):
    model.eval()

    total_loss = 0
    total_correct = 0

    with torch.no_grad():
        for batch in iterator:
            text = batch.text
            predictions = model(text).squeeze(1)

            loss = criterion(predictions, batch.label)
            total_loss += loss.item()

            # Convert logits to probabilities, then to hard 0/1 predictions
            preds = torch.round(torch.sigmoid(predictions))
            total_correct += (preds == batch.label).sum().item()

    return total_loss / len(iterator), total_correct / len(iterator.dataset)

# Train and evaluate the model
num_epochs = 10
for epoch in range(num_epochs):
    train(lstm, train_iterator, optimizer, criterion)
    test_loss, test_acc = evaluate(lstm, test_iterator, criterion)
    print(f'Epoch: {epoch+1}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
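
After training, the model can score new text. The helper below is a hypothetical example (the function name and the whitespace tokenization are our own, chosen to mirror the Field's default preprocessing):

def predict_sentiment(model, sentence):
    model.eval()
    # Same preprocessing as the Field: lowercase + whitespace tokenization;
    # unknown words map to the <unk> index
    tokens = [TEXT.vocab.stoi[t] for t in sentence.lower().split()]
    tensor = torch.LongTensor(tokens).unsqueeze(0)  # [1, seq_len]
    with torch.no_grad():
        prob = torch.sigmoid(model(tensor)).item()
    # Closer to 1.0 means the label mapped to 1 in LABEL.vocab (e.g. 'pos')
    return prob

print(predict_sentiment(lstm, "This movie was absolutely wonderful"))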

Conclusion:
With the code above, we implemented a simple LSTM model and trained and tested it on the IMDB movie review dataset for sentiment classification. Through this case study and the detailed code walkthrough, readers should have gained a deeper understanding of how the LSTM model works and how it is applied.

LSTM is a powerful neural network model that can capture long-term dependencies and is widely used in natural language processing, speech recognition, time series prediction, and other fields.
References:

  1. PyTorch official documentation: https://pytorch.org/docs/stable/index.html
  2. torchtext official documentation: https://torchtext.readthedocs.io/en/latest/
  3. IMDB movie review dataset: https://ai.stanford.edu/~amaas/data/sentiment/
