A basic BERT model framework

Building and training a complete BERT model is a complex and time-consuming task. A BERT model consists of multiple components, including embedding layers, Transformer encoders, and a classification head. Writing complete, production-ready code for all of these components is beyond the scope of this text, so a basic BERT model framework is provided below to illustrate its structure and how its main components are set up.

import torch
import torch.nn as nn

# BERT Model
class BERTModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, max_seq_length, num_classes):
        super(BERTModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = nn.Embedding(max_seq_length, embedding_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding_dim, hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(embedding_dim, num_classes)
        self.dropout = nn.Dropout(p=0.1)
        
    def forward(self, input_ids, attention_mask):
        embedded = self.embedding(input_ids)  # [batch_size, seq_length, embedding_dim]
        positions = torch.arange(0, input_ids.size(1), device=input_ids.device).unsqueeze(0).expand_as(input_ids)
        position_embedded = self.position_embedding(positions)  # [batch_size, seq_length, embedding_dim]
        encoded = self.dropout(embedded + position_embedded)  # [batch_size, seq_length, embedding_dim]
        
        for transformer_block in self.transformer_blocks:
            encoded = transformer_block(encoded, attention_mask)
        
        pooled_output = encoded[:, 0, :]  # [batch_size, embedding_dim], representation of the first ([CLS]-style) token
        logits = self.classifier(pooled_output)  # [batch_size, num_classes]
        return logits


# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_heads):
        super(TransformerBlock, self).__init__()
        
        self.attention = MultiHeadAttention(embedding_dim, num_heads)
        self.feed_forward = FeedForward(hidden_dim, embedding_dim)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        
    def forward(self, x, attention_mask):
        attended = self.attention(x, x, x, attention_mask)  # [batch_size, seq_length, embedding_dim]
        residual1 = x + attended
        normalized1 = self.layer_norm1(residual1)  # [batch_size, seq_length, embedding_dim]
        
        fed_forward = self.feed_forward(normalized1)  # [batch_size, seq_length, embedding_dim]
        residual2 = normalized1 + fed_forward
        normalized2 = self.layer_norm2(residual2)  # [batch_size, seq_length, embedding_dim]
        
        return normalized2


# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        
        self.embedding_dim = embedding_dim  # kept for merging the heads back together in forward
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads  # embedding_dim must be divisible by num_heads
        
        self.q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.k_linear = nn.Linear(embedding_dim, embedding_dim)
        self.v_linear = nn.Linear(embedding_dim, embedding_dim)
        self.out_linear = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(p=0.1)  # dropout applied to the attention weights
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        query = self.q_linear(query)  # [batch_size, seq_length, embedding_dim]
        key = self.k_linear(key)  # [batch_size, seq_length, embedding_dim]
        value = self.v_linear(value)  # [batch_size, seq_length, embedding_dim]
        
        query = self._split_heads(query)  # [batch_size, num_heads, seq_length, head_dim]
        key = self._split_heads(key)  # [batch_size, num_heads, seq_length, head_dim]
        value = self._split_heads(value)  # [batch_size, num_heads, seq_length, head_dim]
        
        scores = torch.matmul(query, key.transpose(-1, -2))  # [batch_size, num_heads, seq_length, seq_length]
        scores = scores / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32, device=scores.device))
        if mask is not None:
            # attention_mask uses 1 for real tokens and 0 for padding; mask out the padding positions
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
        
        attention_outputs = torch.softmax(scores, dim=-1)  # [batch_size, num_heads, seq_length, seq_length]
        attention_outputs = self.dropout(attention_outputs)
        
        attended = torch.matmul(attention_outputs, value)  # [batch_size, num_heads, seq_length, head_dim]
        attended = attended.transpose(1, 2).contiguous()  # [batch_size, seq_length, num_heads, head_dim]
        attended = attended.view(batch_size, -1, self.embedding_dim)  # [batch_size, seq_length, embedding_dim]
        attended = self.out_linear(attended)  # [batch_size, seq_length, embedding_dim]
        
        return attended
        
    def _split_heads(self, x):
        batch_size, seq_length, embedding_dim = x.size()
        x = x.view(batch_size, seq_length, self.num_heads, self.head_dim)
        x = x.transpose(1, 2).contiguous()
        return x


# Feed Forward
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, embedding_dim):
        super(FeedForward, self).__init__()
        
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)
        self.linear2 = nn.Linear(hidden_dim, embedding_dim)
        
    def forward(self, x):
        x = self.linear1(x)  # [batch_size, seq_length, hidden_dim]
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)  # [batch_size, seq_length, embedding_dim]
        return x


# Example usage
vocab_size = 10000      # vocabulary size
embedding_dim = 300     # model dimension (must be divisible by num_heads)
hidden_dim = 768        # feed-forward intermediate dimension
num_layers = 12         # number of Transformer blocks
num_heads = 12          # number of attention heads
max_seq_length = 512    # maximum sequence length for the position embeddings
num_classes = 2         # number of output classes

model = BERTModel(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, max_seq_length, num_classes)
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).long()
attention_mask = torch.tensor([[1, 1, 1, 1, 1]]).long()
logits = model(input_ids, attention_mask)
print(logits.shape)  # [1, num_classes]
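
In this example the attention_mask is all ones because there is no padding. For a padded batch, the mask marks real tokens with 1 and padding positions with 0; a minimal illustration, assuming token id 0 is used for padding:

# Two sequences of different lengths, padded to the same length with token id 0
input_ids = torch.tensor([[1, 2, 3, 4, 5],
                          [6, 7, 8, 0, 0]]).long()
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]]).long()
logits = model(input_ids, attention_mask)
print(logits.shape)  # torch.Size([2, 2])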

This code gives a basic BERT model structure, including the main components: the Transformer block, the multi-head attention mechanism, and the feed-forward network. You will need to adjust the parameters and model structure to suit your task and dataset.
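
For reference, a configuration close to BERT-base (12 layers, 12 attention heads, hidden size 768, feed-forward size 3072, maximum sequence length 512, and a WordPiece vocabulary of 30,522 tokens) could be created with the same class; the values below are illustrative, not a drop-in replacement for the real pre-trained weights:

# Illustrative BERT-base-like configuration using the BERTModel class above
bert_base_like = BERTModel(
    vocab_size=30522,      # size of BERT's WordPiece vocabulary
    embedding_dim=768,     # model (hidden) size
    hidden_dim=3072,       # feed-forward intermediate size
    num_layers=12,
    num_heads=12,
    max_seq_length=512,
    num_classes=2,
)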

Please note that this is only a simplified version. The real BERT model is also pre-trained on tasks such as Masked Language Modeling (MLM) and Next Sentence Prediction (NSP), and using it requires data preprocessing, a loss function, and a training loop. In practice, it is strongly recommended to use a BERT model that has already been extensively pre-trained, such as the pre-trained models in Hugging Face's transformers library, to obtain better results.
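
As a rough illustration of those missing training pieces, here is a minimal fine-tuning loop sketch for the BERTModel defined above; it assumes a hypothetical DataLoader named train_loader that yields (input_ids, attention_mask, labels) batches, and the hyperparameters are placeholders rather than tuned values:

import torch.optim as optim

# Sketch of a fine-tuning loop for the BERTModel above (train_loader is assumed to exist)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)

model.train()
for epoch in range(3):
    for input_ids, attention_mask, labels in train_loader:
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)  # [batch_size, num_classes]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

And here is a minimal sketch of loading a pre-trained model from the transformers library (requires pip install transformers; the bert-base-uncased checkpoint is downloaded on first use):

from transformers import BertTokenizer, BertForSequenceClassification

# Load the pre-trained tokenizer and a classification head with 2 labels
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
pretrained = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

inputs = tokenizer("This is a sample sentence.", return_tensors="pt")
outputs = pretrained(**inputs)
print(outputs.logits.shape)  # torch.Size([1, 2])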
